ml4t_diagnostic-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/trade_shap/characterize.py
@@ -0,0 +1,413 @@
"""Pattern characterization with proper statistical testing.

This module provides PatternCharacterizer for characterizing error patterns
identified through clustering, with:
- Welch's t-test (doesn't assume equal variance)
- Mann-Whitney U test (non-parametric)
- Benjamini-Hochberg FDR correction for multiple testing
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import numpy as np
from scipy import stats

from ml4t.diagnostic.evaluation.trade_shap.models import ErrorPattern

if TYPE_CHECKING:
    from numpy.typing import NDArray


@dataclass
class CharacterizationConfig:
    """Configuration for pattern characterization.

    Attributes:
        alpha: Significance level for statistical tests (default: 0.05)
        top_n_features: Number of top features to include in characterization
        use_fdr_correction: Whether to apply Benjamini-Hochberg FDR correction
        min_samples_per_test: Minimum samples needed for each group in t-test
    """

    alpha: float = 0.05
    top_n_features: int = 5
    use_fdr_correction: bool = True
    min_samples_per_test: int = 3


@dataclass
class FeatureStatistics:
    """Statistical test results for a single feature.

    Attributes:
        feature_name: Name of the feature
        mean_shap: Mean SHAP value in the cluster
        mean_shap_other: Mean SHAP value in other clusters
        p_value_t: P-value from Welch's t-test
        p_value_mw: P-value from Mann-Whitney U test
        q_value_t: FDR-corrected p-value (t-test), if correction applied
        q_value_mw: FDR-corrected p-value (MW test), if correction applied
        is_significant: Whether the feature is statistically significant
    """

    feature_name: str
    mean_shap: float
    mean_shap_other: float
    p_value_t: float
    p_value_mw: float
    q_value_t: float | None = None
    q_value_mw: float | None = None
    is_significant: bool = False


def benjamini_hochberg(
    p_values: list[float], alpha: float = 0.05
) -> tuple[list[float], list[bool]]:
    """Apply Benjamini-Hochberg FDR correction to p-values.

    Args:
        p_values: List of raw p-values
        alpha: Significance level (default: 0.05)

    Returns:
        Tuple of (q_values, is_significant) where:
        - q_values: FDR-adjusted p-values (monotone)
        - is_significant: Boolean mask for significant results

    Note:
        The BH procedure controls the False Discovery Rate (FDR): the expected
        proportion of false discoveries among rejected hypotheses.
        This is less conservative than Bonferroni correction.
    """
    if not p_values:
        return [], []

    n = len(p_values)
    p_array = np.asarray(p_values)

    # Sort p-values and track original order
    sorted_indices = np.argsort(p_array)
    sorted_p = p_array[sorted_indices]

    # BH adjustment: q_i = min(p_i * n / rank, 1.0)
    # Then enforce monotonicity from largest to smallest
    ranks = np.arange(1, n + 1)
    q_sorted = np.minimum(sorted_p * n / ranks, 1.0)

    # Enforce monotonicity: q[i] = min(q[i], q[i+1], ..., q[n])
    # Process from end to start
    for i in range(n - 2, -1, -1):
        q_sorted[i] = min(q_sorted[i], q_sorted[i + 1])

    # Restore original order
    q_values = np.empty(n)
    q_values[sorted_indices] = q_sorted

    # Determine significance
    is_significant = q_values < alpha

    return q_values.tolist(), is_significant.tolist()


class PatternCharacterizer:
    """Characterizes error patterns with proper statistical testing.

    Uses Welch's t-test (doesn't assume equal variance) and the Mann-Whitney U test,
    with optional Benjamini-Hochberg FDR correction for multiple testing.

    Attributes:
        config: Characterization configuration
        feature_names: List of all feature names

    Example:
        >>> characterizer = PatternCharacterizer(feature_names)
        >>> pattern = characterizer.characterize_cluster(
        ...     cluster_shap=cluster_vectors,
        ...     other_shap=other_vectors,
        ...     cluster_id=0,
        ... )
        >>> print(pattern.top_features)
    """

    def __init__(
        self,
        feature_names: list[str],
        config: CharacterizationConfig | None = None,
    ) -> None:
        """Initialize characterizer.

        Args:
            feature_names: List of all feature names
            config: Characterization configuration (uses defaults if None)
        """
        self.feature_names = feature_names
        self.config = config or CharacterizationConfig()

    def characterize_cluster(
        self,
        cluster_shap: NDArray[np.floating[Any]],
        other_shap: NDArray[np.floating[Any]],
        cluster_id: int,
        centroids: NDArray[np.floating[Any]] | None = None,
    ) -> ErrorPattern:
        """Characterize a single cluster as an error pattern.

        Args:
            cluster_shap: SHAP vectors for trades in this cluster (n_cluster x n_features)
            other_shap: SHAP vectors for all other trades (n_other x n_features)
            cluster_id: Cluster identifier (0-indexed)
            centroids: Optional cluster centroids for separation score calculation

        Returns:
            ErrorPattern with statistical characterization
        """
        n_trades = cluster_shap.shape[0]
        n_features = len(self.feature_names)

        # Compute mean SHAP per feature for this cluster
        mean_shap_cluster = np.mean(cluster_shap, axis=0)
        mean_shap_other = (
            np.mean(other_shap, axis=0) if len(other_shap) > 0 else np.zeros(n_features)
        )

        # Statistical tests for each feature
        feature_stats = self._compute_feature_statistics(
            cluster_shap, other_shap, mean_shap_cluster, mean_shap_other
        )

        # Apply FDR correction if configured
        if self.config.use_fdr_correction:
            feature_stats = self._apply_fdr_correction(feature_stats)

        # Sort by absolute mean SHAP (descending)
        feature_stats.sort(key=lambda x: abs(x.mean_shap), reverse=True)

        # Take top N
        top_stats = feature_stats[: self.config.top_n_features]

        # Build top_features tuple list for ErrorPattern
        top_features = [
            (
                fs.feature_name,
                fs.mean_shap,
                fs.p_value_t,
                fs.p_value_mw,
                fs.is_significant,
            )
            for fs in top_stats
        ]

        # Generate pattern description
        description = self._generate_description(top_stats)

        # Compute separation and distinctiveness scores
        separation_score = self._compute_separation_score(mean_shap_cluster, centroids, cluster_id)
        distinctiveness = self._compute_distinctiveness(mean_shap_cluster, mean_shap_other)

        return ErrorPattern(
            cluster_id=cluster_id,
            n_trades=n_trades,
            description=description,
            top_features=top_features,
            separation_score=separation_score,
            distinctiveness=distinctiveness,
        )

    def _compute_feature_statistics(
        self,
        cluster_shap: NDArray[np.floating[Any]],
        other_shap: NDArray[np.floating[Any]],
        mean_shap_cluster: NDArray[np.floating[Any]],
        mean_shap_other: NDArray[np.floating[Any]],
    ) -> list[FeatureStatistics]:
        """Compute statistical tests for each feature.

        Uses Welch's t-test (equal_var=False) instead of the standard t-test
        to handle unequal variances between groups.
        """
        results = []

        for idx, feature_name in enumerate(self.feature_names):
            cluster_values = cluster_shap[:, idx]
            other_values = other_shap[:, idx] if len(other_shap) > 0 else np.array([])

            # Skip if insufficient samples
            if (
                len(cluster_values) < self.config.min_samples_per_test
                or len(other_values) < self.config.min_samples_per_test
            ):
                results.append(
                    FeatureStatistics(
                        feature_name=feature_name,
                        mean_shap=float(mean_shap_cluster[idx]),
                        mean_shap_other=float(mean_shap_other[idx]),
                        p_value_t=1.0,
                        p_value_mw=1.0,
                        is_significant=False,
                    )
                )
                continue

            # Welch's t-test (doesn't assume equal variance)
            # This is the key fix: using equal_var=False
            try:
                t_stat, p_value_t = stats.ttest_ind(cluster_values, other_values, equal_var=False)
                p_value_t = float(p_value_t) if not np.isnan(p_value_t) else 1.0
            except Exception:
                p_value_t = 1.0

            # Mann-Whitney U test (non-parametric)
            try:
                _, p_value_mw = stats.mannwhitneyu(
                    cluster_values, other_values, alternative="two-sided"
                )
                p_value_mw = float(p_value_mw) if not np.isnan(p_value_mw) else 1.0
            except ValueError:
                # Can fail if all values are identical
                p_value_mw = 1.0

            results.append(
                FeatureStatistics(
                    feature_name=feature_name,
                    mean_shap=float(mean_shap_cluster[idx]),
                    mean_shap_other=float(mean_shap_other[idx]),
                    p_value_t=p_value_t,
                    p_value_mw=p_value_mw,
                    # Will be set after FDR correction
                    is_significant=False,
                )
            )

        return results

    def _apply_fdr_correction(
        self, feature_stats: list[FeatureStatistics]
    ) -> list[FeatureStatistics]:
        """Apply Benjamini-Hochberg FDR correction to all p-values.

        This corrects for multiple testing across all features, reducing the
        false-positive rate at the cost of some statistical power.
        """
        if not feature_stats:
            return feature_stats

        # Collect p-values
        p_values_t = [fs.p_value_t for fs in feature_stats]
        p_values_mw = [fs.p_value_mw for fs in feature_stats]

        # Apply BH correction
        q_values_t, sig_t = benjamini_hochberg(p_values_t, self.config.alpha)
        q_values_mw, sig_mw = benjamini_hochberg(p_values_mw, self.config.alpha)

        # Update statistics with corrected values
        corrected = []
        for i, fs in enumerate(feature_stats):
            # Significant if either test rejects after FDR correction
            is_sig = sig_t[i] or sig_mw[i]

            corrected.append(
                FeatureStatistics(
                    feature_name=fs.feature_name,
                    mean_shap=fs.mean_shap,
                    mean_shap_other=fs.mean_shap_other,
                    p_value_t=fs.p_value_t,
                    p_value_mw=fs.p_value_mw,
                    q_value_t=q_values_t[i],
                    q_value_mw=q_values_mw[i],
                    is_significant=is_sig,
                )
            )

        return corrected

    def _generate_description(self, top_stats: list[FeatureStatistics]) -> str:
        """Generate human-readable pattern description."""
        if not top_stats:
            return "Unknown pattern"

        # Filter to significant features only
        sig_features = [fs for fs in top_stats if fs.is_significant]

        # Fall back to top features if none significant
        features_to_use = sig_features[:3] if sig_features else top_stats[:2]

        components = []
        for fs in features_to_use:
            direction = "High" if fs.mean_shap > 0 else "Low"
            arrow = "↑" if fs.mean_shap > 0 else "↓"
            components.append(f"{direction} {fs.feature_name} ({arrow}{fs.mean_shap:.2f})")

        if len(components) == 1:
            return f"{components[0]} → Losses"
        return " + ".join(components) + " → Losses"

    def _compute_separation_score(
        self,
        centroid: NDArray[np.floating[Any]],
        all_centroids: NDArray[np.floating[Any]] | None,
        cluster_id: int,
    ) -> float:
        """Compute separation score (distance to nearest other cluster)."""
        if all_centroids is None or len(all_centroids) <= 1:
            return 0.0

        min_distance = float("inf")
        for i, other_centroid in enumerate(all_centroids):
            if i != cluster_id:
                distance = float(np.linalg.norm(centroid - other_centroid))
                min_distance = min(min_distance, distance)

        return min_distance if min_distance != float("inf") else 0.0

    def _compute_distinctiveness(
        self,
        cluster_centroid: NDArray[np.floating[Any]],
        other_mean: NDArray[np.floating[Any]],
    ) -> float:
        """Compute distinctiveness (ratio of max SHAP vs other clusters)."""
        max_cluster = np.max(np.abs(cluster_centroid))
        max_other = np.max(np.abs(other_mean))

        if max_other == 0:
            return float(max_cluster) if max_cluster > 0 else 1.0

        return float(max_cluster / max_other)

    def characterize_all_clusters(
        self,
        shap_vectors: NDArray[np.floating[Any]],
        cluster_labels: list[int],
        n_clusters: int,
        centroids: NDArray[np.floating[Any]] | None = None,
    ) -> list[ErrorPattern]:
        """Characterize all clusters.

        Args:
            shap_vectors: All SHAP vectors (n_samples x n_features)
            cluster_labels: Cluster assignment for each sample
            n_clusters: Total number of clusters
            centroids: Optional cluster centroids

        Returns:
            List of ErrorPattern for each cluster
        """
        labels_array = np.asarray(cluster_labels)
        patterns = []

        for cluster_id in range(n_clusters):
            mask = labels_array == cluster_id
            cluster_shap = shap_vectors[mask]
            other_shap = shap_vectors[~mask]

            pattern = self.characterize_cluster(
                cluster_shap=cluster_shap,
                other_shap=other_shap,
                cluster_id=cluster_id,
                centroids=centroids,
            )
            patterns.append(pattern)

        return patterns
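For orientation, here is a minimal usage sketch of the characterization API above. It is not part of the package diff: the feature names and SHAP arrays are synthetic, and it only assumes the classes and ErrorPattern fields shown in the file just listed (description, top_features as (name, mean_shap, p_t, p_mw, is_significant) tuples).

# Illustrative only: synthetic SHAP values, hypothetical feature names.
import numpy as np
from ml4t.diagnostic.evaluation.trade_shap.characterize import (
    CharacterizationConfig,
    PatternCharacterizer,
)

rng = np.random.default_rng(0)
feature_names = ["momentum_10d", "spread_bps", "overnight_gap"]

# Cluster 0: strongly positive SHAP on the first feature; "other" trades centered at 0.
cluster_shap = rng.normal(loc=[0.8, 0.0, 0.0], scale=0.1, size=(20, 3))
other_shap = rng.normal(loc=0.0, scale=0.1, size=(60, 3))

characterizer = PatternCharacterizer(
    feature_names,
    config=CharacterizationConfig(alpha=0.05, top_n_features=3),
)
pattern = characterizer.characterize_cluster(
    cluster_shap=cluster_shap,
    other_shap=other_shap,
    cluster_id=0,
)
print(pattern.description)  # e.g. "High momentum_10d (↑0.80) → Losses"
for name, mean_shap, p_t, p_mw, significant in pattern.top_features:
    print(f"{name}: mean SHAP {mean_shap:+.3f}, "
          f"Welch p={p_t:.3g}, MW p={p_mw:.3g}, significant={significant}")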
ml4t/diagnostic/evaluation/trade_shap/cluster.py
@@ -0,0 +1,302 @@
"""Hierarchical clustering for trade error patterns.

Provides clustering of SHAP vectors to identify distinct error patterns,
with proper handling of small sample sizes.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

import numpy as np

from ml4t.diagnostic.evaluation.trade_shap.models import ClusteringResult

if TYPE_CHECKING:
    from numpy.typing import NDArray


DistanceMetric = Literal["euclidean", "cosine", "correlation", "cityblock"]
LinkageMethod = Literal["ward", "average", "complete", "single"]


@dataclass
class ClusteringConfig:
    """Configuration for hierarchical clustering.

    Attributes:
        distance_metric: Distance metric for pdist ('euclidean', 'cosine', etc.)
        linkage_method: Linkage method for hierarchical clustering
        min_cluster_size: Minimum trades per cluster
        min_trades_for_clustering: Minimum trades required to attempt clustering
    """

    distance_metric: DistanceMetric = "euclidean"
    linkage_method: LinkageMethod = "ward"
    min_cluster_size: int = 5
    min_trades_for_clustering: int = 10


def find_optimal_clusters(
    linkage_matrix: NDArray[np.floating[Any]],
    n_samples: int,
    min_cluster_size: int = 5,
) -> int:
    """Find the optimal number of clusters using the elbow method.

    Uses the acceleration of merge distances (second derivative) to find
    the "elbow" point in the dendrogram.

    Args:
        linkage_matrix: Linkage matrix from hierarchical clustering
        n_samples: Total number of samples
        min_cluster_size: Minimum samples per cluster

    Returns:
        Optimal number of clusters respecting the min_cluster_size constraint

    Note:
        The key fix here is respecting min_cluster_size even when that means
        returning 1 cluster. Previously, the code would force 2 clusters even
        when there weren't enough samples to support min_cluster_size per cluster.
    """
    # Get merge distances (third column of the linkage matrix)
    distances = linkage_matrix[:, 2]

    # Compute first derivative (rate of change)
    first_deriv = np.diff(distances)

    # Compute second derivative (acceleration)
    second_deriv = np.diff(first_deriv)

    # Find elbow: maximum acceleration point
    if len(second_deriv) > 0:
        elbow_idx = int(np.argmax(second_deriv))
        # Convert index to number of clusters
        # linkage_matrix has (n_samples - 1) rows
        n_clusters = max(1, n_samples - elbow_idx - 2)
    else:
        # Fallback: sqrt(n) heuristic
        n_clusters = max(1, int(np.sqrt(n_samples)))

    # CRITICAL FIX: Respect min_cluster_size constraint
    # max_clusters is at least 1 to avoid edge case where we'd return 0
    max_clusters = max(1, n_samples // min_cluster_size)
    n_clusters = min(n_clusters, max_clusters)

    # Only force at least 2 clusters if we have room for them
    # This is the bug fix: don't force 2 if max_clusters < 2
    if max_clusters >= 2:
        n_clusters = max(2, n_clusters)

    return int(n_clusters)


def compute_cluster_sizes(
    labels: NDArray[np.intp] | list[int],
    n_clusters: int,
) -> list[int]:
    """Compute the number of samples in each cluster using vectorized bincount.

    Args:
        labels: Cluster assignment for each sample (0-indexed)
        n_clusters: Total number of clusters

    Returns:
        List of cluster sizes
    """
    labels_array = np.asarray(labels, dtype=np.intp)
    counts = np.bincount(labels_array, minlength=n_clusters)
    return counts.tolist()


def compute_centroids(
    vectors: NDArray[np.floating[Any]],
    labels: NDArray[np.intp] | list[int],
    n_clusters: int,
) -> NDArray[np.floating[Any]]:
    """Compute cluster centroids (mean vector per cluster).

    Args:
        vectors: SHAP vectors of shape (n_samples, n_features)
        labels: Cluster assignment for each sample (0-indexed)
        n_clusters: Total number of clusters

    Returns:
        Centroids of shape (n_clusters, n_features)
    """
    labels_array = np.asarray(labels, dtype=np.intp)
    n_features = vectors.shape[1]

    centroids = np.zeros((n_clusters, n_features), dtype=np.float64)

    for k in range(n_clusters):
        mask = labels_array == k
        if np.any(mask):
            centroids[k] = vectors[mask].mean(axis=0)

    return centroids


class HierarchicalClusterer:
    """Hierarchical clustering for SHAP vectors.

    Provides clustering of trade SHAP vectors to identify distinct error patterns,
    with quality metrics and dendrogram support.

    Attributes:
        config: Clustering configuration

    Example:
        >>> clusterer = HierarchicalClusterer()
        >>> result = clusterer.cluster(shap_vectors, n_clusters=3)
        >>> print(f"Silhouette: {result.silhouette_score:.3f}")
    """

    def __init__(self, config: ClusteringConfig | None = None) -> None:
        """Initialize clusterer.

        Args:
            config: Clustering configuration (uses defaults if None)
        """
        self.config = config or ClusteringConfig()

    def cluster(
        self,
        vectors: NDArray[np.floating[Any]],
        n_clusters: int | None = None,
    ) -> ClusteringResult:
        """Cluster SHAP vectors using hierarchical clustering.

        Args:
            vectors: SHAP vectors of shape (n_samples, n_features)
            n_clusters: Number of clusters (auto-determined if None)

        Returns:
            ClusteringResult with assignments, linkage matrix, and quality metrics

        Raises:
            ValueError: If insufficient samples or invalid input shape
            ImportError: If scipy is not installed
        """
        # Validate inputs
        if vectors.size == 0:
            raise ValueError("Cannot cluster empty vectors")

        if vectors.ndim != 2:
            raise ValueError(
                f"vectors must be 2D array (n_samples, n_features), got shape {vectors.shape}"
            )

        n_samples, n_features = vectors.shape

        if n_samples < self.config.min_trades_for_clustering:
            raise ValueError(
                f"Insufficient samples for clustering: {n_samples} < "
                f"{self.config.min_trades_for_clustering}"
            )

        # Import scipy
        try:
            import scipy.cluster.hierarchy as sch
            from scipy.spatial.distance import pdist
        except ImportError as e:
            raise ImportError(
                "scipy required for clustering. Install with: pip install scipy"
            ) from e

        # Compute pairwise distances
        distances = pdist(vectors, metric=self.config.distance_metric)

        # Perform hierarchical clustering
        linkage_matrix = sch.linkage(distances, method=self.config.linkage_method)

        # Determine number of clusters
        if n_clusters is None:
            n_clusters = find_optimal_clusters(
                linkage_matrix, n_samples, self.config.min_cluster_size
            )

        # Cut dendrogram to get cluster assignments
        labels = sch.fcluster(linkage_matrix, t=n_clusters, criterion="maxclust")
        # fcluster returns 1-indexed labels, convert to 0-indexed
        labels = labels - 1

        # Compute cluster metrics
        cluster_sizes = compute_cluster_sizes(labels, n_clusters)
        centroids = compute_centroids(vectors, labels, n_clusters)

        # Compute quality metrics
        silhouette = self._compute_silhouette(vectors, labels)
        davies_bouldin = self._compute_davies_bouldin(vectors, labels)
        calinski_harabasz = self._compute_calinski_harabasz(vectors, labels)

        return ClusteringResult(
            n_clusters=n_clusters,
            cluster_assignments=labels.tolist(),
            linkage_matrix=linkage_matrix,
            centroids=centroids,
            silhouette_score=silhouette,
            davies_bouldin_score=davies_bouldin,
            calinski_harabasz_score=calinski_harabasz,
            cluster_sizes=cluster_sizes,
            distance_metric=self.config.distance_metric,
            linkage_method=self.config.linkage_method,
        )

    def _compute_silhouette(
        self,
        vectors: NDArray[np.floating[Any]],
        labels: NDArray[np.intp],
    ) -> float:
        """Compute silhouette score for clustering quality.

        Returns:
            Silhouette score (-1 to 1, higher is better)
        """
        try:
            from sklearn.metrics import silhouette_score

            # Need at least 2 clusters for silhouette
            unique_labels = np.unique(labels)
            if len(unique_labels) < 2:
                return 0.0

            return float(silhouette_score(vectors, labels))
        except ImportError:
            return 0.0

    def _compute_davies_bouldin(
        self,
        vectors: NDArray[np.floating[Any]],
        labels: NDArray[np.intp],
    ) -> float | None:
        """Compute Davies-Bouldin index (lower is better)."""
        try:
            from sklearn.metrics import davies_bouldin_score

            unique_labels = np.unique(labels)
            if len(unique_labels) < 2:
                return None

            return float(davies_bouldin_score(vectors, labels))
        except ImportError:
            return None

    def _compute_calinski_harabasz(
        self,
        vectors: NDArray[np.floating[Any]],
        labels: NDArray[np.intp],
    ) -> float | None:
        """Compute Calinski-Harabasz score (higher is better)."""
        try:
            from sklearn.metrics import calinski_harabasz_score

            unique_labels = np.unique(labels)
            if len(unique_labels) < 2:
                return None

            return float(calinski_harabasz_score(vectors, labels))
        except ImportError:
            return None
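As a closing illustration (again not part of the diff), the two modules are designed to chain: HierarchicalClusterer groups per-trade SHAP vectors, and PatternCharacterizer turns each group into an ErrorPattern. A minimal sketch on synthetic data, assuming only the class and field names shown in the files above; the vectors and feature names are fabricated for the example.

# Illustrative only: random vectors stand in for per-trade SHAP values.
import numpy as np
from ml4t.diagnostic.evaluation.trade_shap.cluster import ClusteringConfig, HierarchicalClusterer
from ml4t.diagnostic.evaluation.trade_shap.characterize import PatternCharacterizer

rng = np.random.default_rng(1)
feature_names = [f"feat_{i}" for i in range(4)]

# Two synthetic groups of losing trades with different SHAP signatures.
shap_vectors = np.vstack([
    rng.normal(loc=[1.0, 0.0, 0.0, 0.0], scale=0.2, size=(30, 4)),
    rng.normal(loc=[0.0, -1.0, 0.0, 0.0], scale=0.2, size=(30, 4)),
])

clusterer = HierarchicalClusterer(ClusteringConfig(min_cluster_size=5))
result = clusterer.cluster(shap_vectors)  # n_clusters auto-selected via the elbow method
print(result.n_clusters, result.cluster_sizes, round(result.silhouette_score, 3))

patterns = PatternCharacterizer(feature_names).characterize_all_clusters(
    shap_vectors=shap_vectors,
    cluster_labels=result.cluster_assignments,
    n_clusters=result.n_clusters,
    centroids=result.centroids,
)
for p in patterns:
    print(p.cluster_id, p.n_trades, p.description)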