ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,208 @@
1
+ """Trade SHAP explanation logic.
2
+
3
+ This module provides the TradeShapExplainer class that explains individual trades
4
+ using SHAP values, with O(log n) timestamp alignment and efficient feature extraction.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ import numpy as np
12
+
13
+ from ml4t.diagnostic.evaluation.trade_shap.alignment import TimestampAligner
14
+ from ml4t.diagnostic.evaluation.trade_shap.models import (
15
+ TradeExplainFailure,
16
+ TradeShapExplanation,
17
+ )
18
+
19
+ if TYPE_CHECKING:
20
+ import polars as pl
21
+ from numpy.typing import NDArray
22
+
23
+ from ml4t.diagnostic.evaluation.trade_analysis import TradeMetrics
24
+
25
+
26
+ class TradeShapExplainer:
27
+ """Explains individual trades using SHAP values.
28
+
29
+ Uses TimestampAligner for O(log n) timestamp lookup and extracts
30
+ feature values in a single row read for efficiency.
31
+
32
+ Returns TradeExplainFailure for expected failure cases instead of
33
+ throwing exceptions, enabling clean batch processing.
34
+
35
+ Attributes:
36
+ features_df: Polars DataFrame with timestamp and feature columns
37
+ shap_values: 2D numpy array of SHAP values (n_samples x n_features)
38
+ feature_names: List of feature column names
39
+ aligner: TimestampAligner for fast timestamp lookup
40
+ top_n_features: Number of top features to include in explanation
41
+
42
+ Example:
43
+ >>> explainer = TradeShapExplainer(
44
+ ... features_df=features,
45
+ ... shap_values=shap_values,
46
+ ... feature_names=feature_names,
47
+ ... tolerance_seconds=60.0,
48
+ ... )
49
+ >>> result = explainer.explain(trade)
50
+ >>> if isinstance(result, TradeShapExplanation):
51
+ ... print(result.top_features[:3])
52
+ ... else:
53
+ ... print(f"Failed: {result.reason}")
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ features_df: pl.DataFrame,
59
+ shap_values: NDArray[np.floating[Any]],
60
+ feature_names: list[str],
61
+ tolerance_seconds: float = 0.0,
62
+ top_n_features: int | None = None,
63
+ alignment_mode: str = "entry",
64
+ missing_value_strategy: str = "skip",
65
+ ) -> None:
66
+ """Initialize the explainer.
67
+
68
+ Args:
69
+ features_df: Polars DataFrame with 'timestamp' column and feature columns
70
+ shap_values: SHAP values array (n_samples x n_features)
71
+ feature_names: List of feature column names matching shap_values columns
72
+ tolerance_seconds: Maximum seconds for nearest-match alignment (0 = exact only)
73
+ top_n_features: Number of top features to include (None = all)
74
+ alignment_mode: 'entry' for exact match, 'nearest' for closest within tolerance
75
+ missing_value_strategy: How to handle alignment failures ('error', 'skip', 'zero')
76
+
77
+ Raises:
78
+ ValueError: If shap_values shape doesn't match features_df rows or feature_names
79
+ """
80
+ self.features_df = features_df
81
+ self.shap_values = shap_values
82
+ self.feature_names = feature_names
83
+ self.top_n_features = top_n_features
84
+ self.alignment_mode = alignment_mode
85
+ self.missing_value_strategy = missing_value_strategy
86
+
87
+ # Validate shapes
88
+ n_rows = len(features_df)
89
+ n_features = len(feature_names)
90
+
91
+ if shap_values.shape[0] != n_rows:
92
+ raise ValueError(
93
+ f"SHAP values rows ({shap_values.shape[0]}) != features_df rows ({n_rows})"
94
+ )
95
+ if shap_values.shape[1] != n_features:
96
+ raise ValueError(
97
+ f"SHAP values columns ({shap_values.shape[1]}) != feature_names ({n_features})"
98
+ )
99
+
100
+ # Build aligner with appropriate tolerance
101
+ timestamps = features_df["timestamp"].to_list()
102
+ effective_tolerance = tolerance_seconds if alignment_mode == "nearest" else 0.0
103
+ self.aligner = TimestampAligner.from_datetime_index(
104
+ timestamps, tolerance_seconds=effective_tolerance
105
+ )
106
+
107
+ # Cache feature data as numpy for fast row extraction
108
+ self._feature_matrix = features_df.select(feature_names).to_numpy()
109
+
110
+ def explain(
111
+ self,
112
+ trade: TradeMetrics,
113
+ ) -> TradeShapExplanation | TradeExplainFailure:
114
+ """Explain a single trade.
115
+
116
+ Args:
117
+ trade: Trade to explain (must have timestamp and symbol attributes)
118
+
119
+ Returns:
120
+ TradeShapExplanation on success, TradeExplainFailure on expected failures
121
+ """
122
+ trade_id = f"{trade.symbol}_{trade.timestamp.isoformat()}"
123
+
124
+ # Align to timestamp
125
+ result = self.aligner.align(trade.timestamp)
126
+
127
+ if result.index is None:
128
+ # Handle alignment failure based on strategy
129
+ if self.missing_value_strategy == "error":
130
+ raise ValueError(
131
+ f"Cannot align SHAP values for trade {trade_id}: "
132
+ f"no timestamp within {self.aligner.tolerance_seconds}s "
133
+ f"(nearest is {result.distance_seconds:.1f}s away)"
134
+ )
135
+ elif self.missing_value_strategy == "zero":
136
+ # Return zero SHAP vector
137
+ shap_vector = np.zeros(len(self.feature_names))
138
+ feature_values = dict.fromkeys(self.feature_names, 0.0)
139
+ top_features = [(name, 0.0) for name in self.feature_names]
140
+ return TradeShapExplanation(
141
+ trade_id=trade_id,
142
+ timestamp=trade.timestamp,
143
+ top_features=top_features,
144
+ feature_values=feature_values,
145
+ shap_vector=shap_vector,
146
+ )
147
+ else: # "skip" or default
148
+ return TradeExplainFailure(
149
+ trade_id=trade_id,
150
+ timestamp=trade.timestamp,
151
+ reason="alignment_missing",
152
+ details={
153
+ "alignment_mode": self.alignment_mode,
154
+ "tolerance_seconds": self.aligner.tolerance_seconds,
155
+ "distance_seconds": result.distance_seconds,
156
+ },
157
+ )
158
+
159
+ idx = result.index
160
+
161
+ # Extract SHAP vector for this row
162
+ shap_vector = np.asarray(self.shap_values[idx, :], dtype=np.float64)
163
+
164
+ # Extract feature values in one row read (not per-feature loop)
165
+ feature_row = self._feature_matrix[idx, :]
166
+ feature_values = {
167
+ name: float(val) for name, val in zip(self.feature_names, feature_row, strict=True)
168
+ }
169
+
170
+ # Get top N contributors by absolute SHAP value
171
+ top_n = self.top_n_features if self.top_n_features is not None else len(self.feature_names)
172
+
173
+ # Create (feature_name, shap_value) pairs and sort by |shap|
174
+ feature_shap_pairs = list(zip(self.feature_names, shap_vector.tolist(), strict=True))
175
+ feature_shap_pairs.sort(key=lambda x: abs(x[1]), reverse=True)
176
+ top_features = [(name, float(val)) for name, val in feature_shap_pairs[:top_n]]
177
+
178
+ return TradeShapExplanation(
179
+ trade_id=trade_id,
180
+ timestamp=trade.timestamp,
181
+ top_features=top_features,
182
+ feature_values=feature_values,
183
+ shap_vector=shap_vector,
184
+ )
185
+
186
+ def explain_many(
187
+ self,
188
+ trades: list[TradeMetrics],
189
+ ) -> tuple[list[TradeShapExplanation], list[TradeExplainFailure]]:
190
+ """Explain multiple trades.
191
+
192
+ Args:
193
+ trades: List of trades to explain
194
+
195
+ Returns:
196
+ Tuple of (successful explanations, failures)
197
+ """
198
+ explanations: list[TradeShapExplanation] = []
199
+ failures: list[TradeExplainFailure] = []
200
+
201
+ for trade in trades:
202
+ result = self.explain(trade)
203
+ if isinstance(result, TradeShapExplanation):
204
+ explanations.append(result)
205
+ else:
206
+ failures.append(result)
207
+
208
+ return explanations, failures
@@ -0,0 +1,23 @@
1
+ """Hypothesis generation for trade SHAP error patterns.
2
+
3
+ This package provides template-based hypothesis generation for explaining
4
+ why trading patterns cause losses, with templates stored as YAML data.
5
+ """
6
+
7
+ from ml4t.diagnostic.evaluation.trade_shap.hypotheses.generator import (
8
+ HypothesisConfig,
9
+ HypothesisGenerator,
10
+ )
11
+ from ml4t.diagnostic.evaluation.trade_shap.hypotheses.matcher import (
12
+ Template,
13
+ TemplateMatcher,
14
+ load_templates,
15
+ )
16
+
17
# Names re-exported as this package's public API (original order preserved).
__all__ = [
    "HypothesisGenerator", "HypothesisConfig",
    "TemplateMatcher", "Template",
    "load_templates",
]
@@ -0,0 +1,290 @@
1
+ """Hypothesis generator for trade SHAP error patterns.
2
+
3
+ Generates actionable hypotheses and improvement suggestions based on
4
+ template matching against error pattern features.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from typing import TYPE_CHECKING, Any
11
+
12
+ from ml4t.diagnostic.evaluation.trade_shap.hypotheses.matcher import (
13
+ TemplateMatcher,
14
+ load_templates,
15
+ )
16
+
17
+ if TYPE_CHECKING:
18
+ from ml4t.diagnostic.evaluation.trade_shap.models import ErrorPattern
19
+
20
+
21
@dataclass
class HypothesisConfig:
    """Configuration for hypothesis generation.

    Attributes:
        template_library: Which template library to use ('comprehensive' or 'minimal')
        min_confidence: Minimum confidence threshold for generating a hypothesis,
            within [0, 1]
        max_actions: Maximum number of actions to include (must be >= 0)

    Raises:
        ValueError: If ``min_confidence`` is outside [0, 1] (or NaN) or
            ``max_actions`` is negative.
    """

    template_library: str = "comprehensive"
    min_confidence: float = 0.5
    max_actions: int = 4

    def __post_init__(self) -> None:
        # Written as a negated range check so that NaN is rejected as well.
        if not 0.0 <= self.min_confidence <= 1.0:
            raise ValueError(f"min_confidence must be within [0, 1], got {self.min_confidence!r}")
        # max_actions is used as a slice bound (``actions[:max_actions]``); a
        # negative value would silently drop trailing actions rather than
        # limiting how many are kept.
        if self.max_actions < 0:
            raise ValueError(f"max_actions must be >= 0, got {self.max_actions!r}")
34
+
35
+
36
class HypothesisGenerator:
    """Generates hypotheses for error patterns using template matching.

    Matches error pattern features against a library of templates and
    generates actionable hypotheses about why the pattern causes losses.

    Attributes:
        config: Hypothesis generation configuration
        matcher: Template matcher

    Example:
        >>> generator = HypothesisGenerator()
        >>> enriched = generator.generate_hypothesis(error_pattern)
        >>> print(enriched.hypothesis)
        >>> print(enriched.actions)
    """

    def __init__(self, config: HypothesisConfig | Any | None = None) -> None:
        """Initialize generator.

        Args:
            config: Hypothesis configuration (uses defaults if None).
                Accepts HypothesisConfig dataclass or HypothesisGenerationConfig Pydantic model.
        """
        # Normalize config to HypothesisConfig dataclass
        self.config = self._normalize_config(config)

        # Load templates and create matcher
        templates = load_templates(self.config.template_library)
        self.matcher = TemplateMatcher(templates)

    def _normalize_config(self, config: Any) -> HypothesisConfig:
        """Normalize config to HypothesisConfig dataclass.

        Supports both HypothesisConfig dataclass and HypothesisGenerationConfig
        Pydantic model; any other object is read attribute-by-attribute, with
        defaults for missing attributes.
        """
        if config is None:
            return HypothesisConfig()

        if isinstance(config, HypothesisConfig):
            return config

        # Handle Pydantic HypothesisGenerationConfig or similar
        return HypothesisConfig(
            template_library=getattr(config, "template_library", "comprehensive"),
            min_confidence=getattr(config, "min_confidence", 0.5),
            max_actions=getattr(config, "max_actions", 4),
        )

    def generate_hypothesis(
        self,
        error_pattern: ErrorPattern,
        feature_names: list[str] | None = None,
    ) -> ErrorPattern:
        """Generate hypothesis for an error pattern.

        Args:
            error_pattern: Error pattern to analyze. Each ``top_features`` entry
                is indexed as (name, mean_shap, p_value_t, p_value_mw, is_significant).
            feature_names: Currently unused; accepted for API compatibility

        Returns:
            A new ErrorPattern with hypothesis, actions, and confidence fields
            populated, or the input pattern unchanged when no template matches
            with at least ``config.min_confidence``
        """
        # Local import, as in the original (likely to avoid a circular import
        # with the models module).
        from ml4t.diagnostic.evaluation.trade_shap.models import ErrorPattern

        # Parse top_features into dict format for matcher
        pattern_features = [
            {
                "name": feat[0],
                "mean_shap": feat[1],
                "p_value_t": feat[2],
                "p_value_mw": feat[3],
                "is_significant": feat[4],
            }
            for feat in error_pattern.top_features
        ]

        # Try to match a template
        match_result = self.matcher.match(pattern_features)

        if match_result is None or match_result.confidence < self.config.min_confidence:
            # No good match - return pattern unchanged
            return error_pattern

        # Format hypothesis from template
        hypothesis = self._format_hypothesis(
            match_result.template.hypothesis_template,
            match_result.matched_features,
        )

        # Get actions (limit to max; a negative limit keeps none instead of
        # slicing actions off the end)
        actions = match_result.template.actions[: max(self.config.max_actions, 0)]

        # Adjust confidence based on pattern characteristics
        adjusted_confidence = self._adjust_confidence(
            match_result.confidence,
            error_pattern.n_trades,
            error_pattern.separation_score,
        )

        # Return enriched pattern
        return ErrorPattern(
            cluster_id=error_pattern.cluster_id,
            n_trades=error_pattern.n_trades,
            description=error_pattern.description,
            top_features=error_pattern.top_features,
            separation_score=error_pattern.separation_score,
            distinctiveness=error_pattern.distinctiveness,
            hypothesis=hypothesis,
            actions=actions,
            confidence=adjusted_confidence,
        )

    def _format_hypothesis(
        self,
        template: str,
        matched_features: list[dict[str, Any]],
    ) -> str:
        """Format hypothesis string from template.

        Substitutes the {feature} placeholder with the first matched feature
        name, or with the first two significant feature names joined by
        " and " when more than one matched feature is significant.
        """
        if not matched_features:
            return template.replace("{feature}", "the feature")

        # Use first matched feature name
        feature_name = matched_features[0]["name"]

        # If multiple significant features, mention them
        sig_features = [f for f in matched_features if f["is_significant"]]
        if len(sig_features) > 1:
            names = [f["name"] for f in sig_features[:2]]
            feature_name = " and ".join(names)

        return template.replace("{feature}", feature_name)

    def _adjust_confidence(
        self,
        base_confidence: float,
        n_trades: int,
        separation_score: float,
    ) -> float:
        """Adjust confidence based on pattern characteristics.

        - More trades = higher confidence (larger sample)
        - Higher separation = higher confidence (more distinct pattern)
        - Very small samples or poor separation get significant penalties

        The result is clamped to [0, 1].
        """
        # Trade count adjustment - penalize small samples heavily
        if n_trades >= 20:
            trade_boost = 0.05
        elif n_trades >= 10:
            trade_boost = 0.02
        elif n_trades >= 5:
            trade_boost = -0.10
        elif n_trades >= 2:
            trade_boost = -0.25
        else:
            # Single trade - very unreliable
            trade_boost = -0.50

        # Separation score adjustment - penalize poor cluster separation
        if separation_score >= 1.5:
            sep_boost = 0.05
        elif separation_score >= 1.0:
            sep_boost = 0.02
        elif separation_score >= 0.5:
            sep_boost = -0.20  # Moderate separation needs noticeable penalty
        elif separation_score >= 0.3:
            sep_boost = -0.35
        else:
            # Very poor separation - cluster is not distinct
            sep_boost = -0.50

        adjusted = base_confidence + trade_boost + sep_boost
        return max(0.0, min(1.0, adjusted))

    def generate_actions(
        self,
        error_pattern: ErrorPattern,
        max_actions: int | None = None,
    ) -> list[dict[str, Any]]:
        """Generate prioritized action suggestions for an error pattern.

        Args:
            error_pattern: Error pattern with hypothesis
            max_actions: Maximum actions to return (defaults to config);
                negative values return no actions

        Returns:
            List of action dictionaries with category, description, priority, etc.
        """
        if max_actions is None:
            max_actions = self.config.max_actions

        if not error_pattern.actions:
            return []

        # Clamp so a negative limit yields no actions rather than slicing
        # actions off the end of the list.
        limit = max(max_actions, 0)

        return [
            {
                "category": self._categorize_action(action),
                "description": action,
                # Priority based on position and confidence
                "priority": self._determine_priority(i, error_pattern.confidence),
                "implementation_difficulty": self._estimate_difficulty(action),
                "rationale": f"Based on pattern: {error_pattern.description}",
            }
            for i, action in enumerate(error_pattern.actions[:limit])
        ]

    def _categorize_action(self, action: str) -> str:
        """Categorize an action by case-insensitive keyword (substring) match.

        Categories are checked in order, so the first matching one wins.
        """
        action_lower = action.lower()

        if any(word in action_lower for word in ["feature", "indicator", "add"]):
            return "feature_engineering"
        elif any(word in action_lower for word in ["filter", "regime", "threshold"]):
            return "filter_regime"
        elif any(word in action_lower for word in ["size", "position", "stop", "risk"]):
            return "risk_management"
        elif any(word in action_lower for word in ["tune", "parameter", "adjust"]):
            return "model_adjustment"
        else:
            return "general"

    def _determine_priority(self, position: int, confidence: float | None) -> str:
        """Determine action priority from list position and pattern confidence.

        A missing (None) confidence is treated as 0.5. A real confidence of
        0.0 is kept as-is, so zero-confidence patterns do not get promoted.
        """
        conf = 0.5 if confidence is None else confidence

        if position == 0 and conf >= 0.7:
            return "high"
        elif position <= 1 and conf >= 0.5:
            return "medium"
        else:
            return "low"

    def _estimate_difficulty(self, action: str) -> str:
        """Estimate implementation difficulty from action text keywords."""
        action_lower = action.lower()

        if any(word in action_lower for word in ["implement", "hmm", "model", "ensemble"]):
            return "hard"
        elif any(word in action_lower for word in ["add", "consider", "track"]):
            return "medium"
        else:
            return "easy"