ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/trade_shap/models.py
@@ -0,0 +1,386 @@
+"""Pydantic models for Trade SHAP diagnostics.
+
+This module contains the data models used throughout the Trade SHAP analysis:
+- TradeShapExplanation: SHAP explanation for a single trade
+- ClusteringResult: Result of error pattern clustering
+- ErrorPattern: Characterized error pattern from clustered trades
+- TradeShapResult: Complete result of trade-level SHAP analysis
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+from numpy.typing import NDArray
+from pydantic import BaseModel, Field
+
+if TYPE_CHECKING:
+    pass
+
+
+class TradeExplainFailure(BaseModel):
+    """Structured failure result for trade explanation.
+
+    Used instead of exceptions for expected failure cases (alignment missing,
+    feature mismatch, etc.) to enable batch processing without try/except.
+
+    Attributes:
+        trade_id: Unique trade identifier
+        timestamp: Trade entry timestamp
+        reason: Machine-readable failure reason code
+        details: Additional context about the failure
+    """
+
+    trade_id: str = Field(..., description="Unique trade identifier")
+    timestamp: datetime = Field(..., description="Trade entry timestamp")
+    reason: str = Field(
+        ...,
+        description="Failure reason: 'alignment_missing', 'shap_error', 'feature_mismatch'",
+    )
+    details: dict[str, Any] = Field(default_factory=dict, description="Additional failure context")
+
+
+class TradeShapExplanation(BaseModel):
+    """SHAP explanation for a single trade.
+
+    Contains SHAP attribution details for one trade, including:
+    - Top contributing features (sorted by absolute SHAP value)
+    - Feature values at trade entry
+    - Full SHAP vector for all features
+    - Waterfall plot data (future enhancement)
+
+    Attributes:
+        trade_id: Unique trade identifier (symbol_timestamp)
+        timestamp: Trade entry timestamp
+        top_features: List of (feature_name, shap_value) sorted by |shap_value| descending
+        feature_values: Dictionary of feature values at trade entry
+        shap_vector: Full SHAP vector for all features (numpy array)
+
+    Example:
+        >>> explanation.top_features[:3]
+        [('momentum_20d', 0.342), ('volatility_10d', -0.215), ('rsi_14d', 0.108)]
+
+        >>> explanation.feature_values['momentum_20d']
+        1.235
+
+        >>> explanation.shap_vector.shape
+        (50,)  # 50 features
+    """
+
+    trade_id: str = Field(..., description="Unique trade identifier")
+    timestamp: datetime = Field(..., description="Trade entry timestamp")
+    top_features: list[tuple[str, float]] = Field(
+        ..., description="Top N features by absolute SHAP value (descending)"
+    )
+    feature_values: dict[str, float] = Field(
+        ..., description="Feature values at trade entry timestamp"
+    )
+    shap_vector: NDArray[np.floating[Any]] = Field(
+        ..., description="Full SHAP vector for all features"
+    )
+
+    class Config:
+        """Pydantic config."""
+
+        arbitrary_types_allowed = True
+
+
+class ClusteringResult(BaseModel):
+    """Result of error pattern clustering.
+
+    Contains cluster assignments, centroids, quality metrics, and linkage matrix
+    for dendrogram visualization.
+
+    Attributes:
+        n_clusters: Number of clusters identified
+        cluster_assignments: Cluster ID for each trade (0-indexed list)
+        linkage_matrix: Scipy linkage matrix for dendrogram plotting
+        centroids: Mean SHAP vector for each cluster (shape: n_clusters x n_features)
+        silhouette_score: Quality metric (range: -1 to 1, higher is better)
+            - 1.0: Perfect separation
+            - 0.5: Good separation
+            - 0.0: Overlapping clusters
+            - <0.0: Poor clustering (mis-assigned trades)
+        davies_bouldin_score: Davies-Bouldin Index (lower = better, min: 0)
+            - Measures ratio of within-cluster to between-cluster distances
+            - < 1.0: Good clustering
+            - 1.0-2.0: Acceptable clustering
+            - > 2.0: Poor clustering
+        calinski_harabasz_score: Calinski-Harabasz Score (higher = better, min: 0)
+            - Also known as Variance Ratio Criterion
+            - Measures ratio of between-cluster to within-cluster dispersion
+            - Higher values indicate better-defined clusters
+        cluster_sizes: Number of trades in each cluster
+        distance_metric: Distance metric used ('euclidean', 'cosine', etc.)
+        linkage_method: Linkage method used ('ward', 'average', 'complete', 'single')
+
+    Example - Basic inspection:
+        >>> result = analyzer.cluster_patterns(shap_vectors)
+        >>> print(f"Found {result.n_clusters} clusters")
+        >>> print(f"Cluster sizes: {result.cluster_sizes}")
+        >>> print(f"Quality (silhouette): {result.silhouette_score:.3f}")
+
+    Example - Visualize dendrogram:
+        >>> from scipy.cluster.hierarchy import dendrogram
+        >>> import matplotlib.pyplot as plt
+        >>> dendrogram(result.linkage_matrix)
+        >>> plt.title("Error Pattern Dendrogram")
+        >>> plt.xlabel("Trade Index")
+        >>> plt.ylabel("Distance")
+        >>> plt.show()
+
+    Example - Analyze specific cluster:
+        >>> cluster_id = 0
+        >>> trades_in_cluster = [i for i, c in enumerate(result.cluster_assignments) if c == cluster_id]
+        >>> cluster_centroid = result.centroids[cluster_id]
+        >>> print(f"Cluster {cluster_id}: {len(trades_in_cluster)} trades")
+        >>> print(f"Centroid (mean SHAP): {cluster_centroid}")
+
+    Note:
+        - linkage_matrix can be used directly with scipy.cluster.hierarchy.dendrogram()
+        - centroids represent "typical" SHAP pattern for each cluster
+        - silhouette_score > 0.5 indicates well-separated clusters
+    """
+
+    n_clusters: int = Field(..., description="Number of clusters identified")
+    cluster_assignments: list[int] = Field(..., description="Cluster ID for each trade (0-indexed)")
+    linkage_matrix: NDArray[np.floating[Any]] = Field(
+        ..., description="Scipy linkage matrix for dendrogram"
+    )
+    centroids: NDArray[np.floating[Any]] = Field(
+        ..., description="Mean SHAP vector per cluster (n_clusters x n_features)"
+    )
+    silhouette_score: float = Field(
+        ..., description="Cluster quality metric (range: -1 to 1, higher is better)"
+    )
+    davies_bouldin_score: float | None = Field(
+        None,
+        description="Davies-Bouldin Index (lower = better, min: 0, no upper bound). "
+        "Measures ratio of within-cluster to between-cluster distances. "
+        "Values < 1.0 indicate good clustering.",
+    )
+    calinski_harabasz_score: float | None = Field(
+        None,
+        description="Calinski-Harabasz Score (higher = better, min: 0, no upper bound). "
+        "Also known as Variance Ratio Criterion. "
+        "Measures ratio of between-cluster to within-cluster dispersion.",
+    )
+    cluster_sizes: list[int] = Field(..., description="Number of trades per cluster")
+    distance_metric: str = Field(..., description="Distance metric used for clustering")
+    linkage_method: str = Field(..., description="Linkage method used for clustering")
+
+    class Config:
+        """Pydantic config."""
+
+        arbitrary_types_allowed = True
+
+
+class ErrorPattern(BaseModel):
+    """Characterized error pattern from clustered trades.
+
+    Represents a distinct pattern of trading errors identified through SHAP-based
+    clustering and statistical characterization. Contains the defining features,
+    quality metrics, and (optionally) generated hypotheses and action suggestions.
+
+    Attributes:
+        cluster_id: Unique identifier for this error pattern (0-indexed)
+        n_trades: Number of trades exhibiting this pattern
+        description: Human-readable pattern description
+            Format: "High feature_X (up 0.45) + Low feature_Y (down -0.32) -> Losses"
+        top_features: Top contributing SHAP features
+            List of (feature_name, mean_shap, p_value_t, p_value_mw, is_significant)
+        separation_score: Distance to nearest other cluster (higher = more distinct)
+        distinctiveness: Ratio of max SHAP vs other clusters (higher = more unique)
+        hypothesis: Optional generated hypothesis about why pattern causes losses
+        actions: Optional list of suggested remediation actions
+        confidence: Optional confidence score for hypothesis (0-1)
+
+    Example - Basic pattern:
+        >>> pattern = ErrorPattern(
+        ...     cluster_id=0,
+        ...     n_trades=15,
+        ...     description="High momentum (up 0.45) + Low volatility (down -0.32) -> Losses",
+        ...     top_features=[
+        ...         ("momentum_20d", 0.45, 0.001, 0.002, True),
+        ...         ("volatility_10d", -0.32, 0.003, 0.004, True)
+        ...     ],
+        ...     separation_score=1.2,
+        ...     distinctiveness=1.8
+        ... )
+        >>> print(pattern.summary())
+        "Pattern 0: 15 trades - High momentum (up 0.45) + Low volatility (down -0.32) -> Losses"
+
+    Example - With hypothesis and actions:
+        >>> pattern = ErrorPattern(
+        ...     cluster_id=1,
+        ...     n_trades=22,
+        ...     description="High RSI (up 0.38) + High volume (up 0.29) -> Losses",
+        ...     top_features=[("rsi_14", 0.38, 0.001, 0.001, True)],
+        ...     separation_score=0.9,
+        ...     distinctiveness=1.5,
+        ...     hypothesis="Trades entering overbought conditions with high volume (potential reversals)",
+        ...     actions=[
+        ...         "Add overbought filter: skip trades when RSI > 70",
+        ...         "Consider volume profile: avoid high volume in overbought zones",
+        ...         "Add mean reversion features to capture reversal dynamics"
+        ...     ],
+        ...     confidence=0.85
+        ... )
+        >>> for action in pattern.actions:
+        ...     print(f"  - {action}")
+
+    Note:
+        - hypothesis, actions, and confidence are populated by HypothesisGenerator
+        - top_features are sorted by absolute SHAP value (descending)
+        - separation_score and distinctiveness are quality metrics for pattern validation
+    """
+
+    cluster_id: int = Field(..., description="Cluster identifier (0-indexed)", ge=0)
+    n_trades: int = Field(..., description="Number of trades in this pattern", gt=0)
+    description: str = Field(..., description="Human-readable pattern description", min_length=1)
+    top_features: list[tuple[str, float, float, float, bool]] = Field(
+        ...,
+        description="Top SHAP features: (name, mean_shap, p_value_t, p_value_mw, is_significant)",
+    )
+    separation_score: float = Field(
+        ..., description="Distance to nearest other cluster (higher = better)", ge=0.0
+    )
+    distinctiveness: float = Field(
+        ..., description="Ratio of max SHAP vs other clusters (higher = better)", gt=0.0
+    )
+    hypothesis: str | None = Field(
+        None, description="Generated hypothesis about why this pattern causes losses"
+    )
+    actions: list[str] | None = Field(
+        None, description="Suggested remediation actions for this pattern"
+    )
+    confidence: float | None = Field(
+        None, description="Confidence score for hypothesis (0-1)", ge=0.0, le=1.0
+    )
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert ErrorPattern to dictionary.
+
+        Returns:
+            Dictionary representation suitable for JSON serialization
+
+        Example:
+            >>> pattern_dict = pattern.to_dict()
+            >>> import json
+            >>> json.dumps(pattern_dict, indent=2)
+        """
+        return {
+            "cluster_id": self.cluster_id,
+            "n_trades": self.n_trades,
+            "description": self.description,
+            "top_features": [
+                {
+                    "feature_name": feat[0],
+                    "mean_shap": feat[1],
+                    "p_value_t": feat[2],
+                    "p_value_mw": feat[3],
+                    "is_significant": feat[4],
+                }
+                for feat in self.top_features
+            ],
+            "separation_score": self.separation_score,
+            "distinctiveness": self.distinctiveness,
+            "hypothesis": self.hypothesis,
+            "actions": self.actions if self.actions else [],
+            "confidence": self.confidence,
+        }
+
+    def summary(self, include_actions: bool = False) -> str:
+        """Generate human-readable summary of error pattern.
+
+        Args:
+            include_actions: Whether to include action suggestions in summary
+
+        Returns:
+            Formatted summary string
+
+        Example:
+            >>> print(pattern.summary())
+            "Pattern 0: 15 trades - High momentum (up 0.45) + Low volatility (down -0.32) -> Losses"
+
+            >>> print(pattern.summary(include_actions=True))
+            '''
+            Pattern 0: 15 trades
+            Description: High momentum (up 0.45) + Low volatility (down -0.32) -> Losses
+            Hypothesis: Trades entering overbought conditions
+            Actions:
+              - Add overbought filter: skip trades when RSI > 70
+              - Consider volume profile
+            Confidence: 85%
+            '''
+        """
+        if not include_actions or not self.hypothesis:
+            # Simple one-line summary
+            return f"Pattern {self.cluster_id}: {self.n_trades} trades - {self.description}"

+        # Detailed multi-line summary with hypothesis and actions
+        lines = [
+            f"Pattern {self.cluster_id}: {self.n_trades} trades",
+            f"Description: {self.description}",
+        ]
+
+        if self.hypothesis:
+            lines.append(f"Hypothesis: {self.hypothesis}")
+
+        if self.actions:
+            lines.append("Actions:")
+            for action in self.actions:
+                lines.append(f"  - {action}")
+
+        if self.confidence is not None:
+            lines.append(f"Confidence: {self.confidence:.0%}")
+
+        return "\n".join(lines)
+
+    class Config:
+        """Pydantic config."""
+
+        arbitrary_types_allowed = True
+
+
+class TradeShapResult(BaseModel):
+    """Complete result of trade-level SHAP analysis.
+
+    Contains SHAP explanations for multiple trades, along with error patterns
+    and actionable recommendations.
+
+    Attributes:
+        n_trades_analyzed: Total number of trades attempted to analyze
+        n_trades_explained: Number of trades successfully explained
+        n_trades_failed: Number of trades that failed explanation
+        explanations: List of successful TradeShapExplanation objects
+        failed_trades: List of (trade_id, error_message) tuples for failed trades
+        error_patterns: Identified error patterns from clustering
+
+    Example:
+        >>> result = analyzer.explain_worst_trades(trades, n=20)
+        >>> print(f"Success rate: {result.n_trades_explained}/{result.n_trades_analyzed}")
+        >>> for explanation in result.explanations:
+        ...     print(f"Trade {explanation.trade_id}: top feature = {explanation.top_features[0]}")
+    """
+
+    n_trades_analyzed: int = Field(..., description="Total trades analyzed")
+    n_trades_explained: int = Field(..., description="Trades successfully explained")
+    n_trades_failed: int = Field(..., description="Trades that failed explanation")
+    explanations: list[TradeShapExplanation] = Field(
+        default_factory=list, description="Successful SHAP explanations"
+    )
+    failed_trades: list[tuple[str, str]] = Field(
+        default_factory=list, description="Failed trades: (trade_id, error_message)"
+    )
+    error_patterns: list[ErrorPattern] = Field(
+        default_factory=list,
+        description="Identified error patterns (populated by clustering and characterization)",
+    )
+
+    class Config:
+        """Pydantic config."""
+
+        arbitrary_types_allowed = True
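For orientation, a minimal sketch of how these models compose. It is illustrative only: the trade values are invented, and the import path is inferred from entry 117 in the file list above.

    from datetime import datetime

    import numpy as np

    from ml4t.diagnostic.evaluation.trade_shap.models import (
        TradeShapExplanation,
        TradeShapResult,
    )

    # One explained trade (hypothetical values, for illustration only).
    explanation = TradeShapExplanation(
        trade_id="AAPL_2024-01-05",
        timestamp=datetime(2024, 1, 5, 14, 30),
        top_features=[("momentum_20d", 0.342), ("volatility_10d", -0.215)],
        feature_values={"momentum_20d": 1.235, "volatility_10d": 0.18},
        shap_vector=np.zeros(50),  # one SHAP value per feature
    )

    # Batch-level result; failures are recorded as data, not raised.
    result = TradeShapResult(
        n_trades_analyzed=20,
        n_trades_explained=19,
        n_trades_failed=1,
        explanations=[explanation],
        failed_trades=[("MSFT_2024-01-08", "alignment_missing")],
    )
    print(f"Explained {result.n_trades_explained}/{result.n_trades_analyzed} trades")

The cluster-quality metrics documented on ClusteringResult (silhouette, Davies-Bouldin, Calinski-Harabasz) are the standard definitions. A sketch of computing them on synthetic data, assuming scikit-learn is available (this diff does not confirm it as a dependency):

    import numpy as np
    from sklearn.metrics import (
        calinski_harabasz_score,
        davies_bouldin_score,
        silhouette_score,
    )

    rng = np.random.default_rng(0)
    shap_vectors = rng.normal(size=(40, 5))  # 40 trades x 5 features (synthetic)
    labels = np.repeat([0, 1], 20)           # two pretend clusters

    print(silhouette_score(shap_vectors, labels))         # -1 to 1, higher is better
    print(davies_bouldin_score(shap_vectors, labels))     # >= 0, lower is better
    print(calinski_harabasz_score(shap_vectors, labels))  # >= 0, higher is better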
ml4t/diagnostic/evaluation/trade_shap/normalize.py
@@ -0,0 +1,116 @@
+"""Normalization functions for SHAP vector clustering.
+
+Provides L1, L2, and standardization normalization with proper
+handling of edge cases (zero vectors, zero variance).
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal
+
+import numpy as np
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+NormalizationType = Literal["l1", "l2", "standardize", "none"]
+
+
+def normalize_l1(vectors: NDArray[np.floating[Any]]) -> NDArray[np.floating[Any]]:
+    """L1 normalization: Scale each row by sum of absolute values.
+
+    Args:
+        vectors: Input vectors of shape (n_samples, n_features)
+
+    Returns:
+        L1-normalized vectors where each row sums to 1.0 (in absolute terms)
+
+    Note:
+        Zero vectors are returned unchanged (no division by zero)
+    """
+    l1_norms = np.sum(np.abs(vectors), axis=1, keepdims=True)
+    l1_norms = np.where(l1_norms == 0, 1.0, l1_norms)
+    return vectors / l1_norms
+
+
+def normalize_l2(vectors: NDArray[np.floating[Any]]) -> NDArray[np.floating[Any]]:
+    """L2 normalization: Scale each row to unit Euclidean norm.
+
+    Args:
+        vectors: Input vectors of shape (n_samples, n_features)
+
+    Returns:
+        L2-normalized unit vectors (norm = 1.0 per row)
+
+    Note:
+        Zero vectors are returned unchanged (no division by zero)
+    """
+    l2_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
+    l2_norms = np.where(l2_norms == 0, 1.0, l2_norms)
+    return vectors / l2_norms
+
+
+def standardize(vectors: NDArray[np.floating[Any]]) -> NDArray[np.floating[Any]]:
+    """Z-score standardization: (x - mean) / std per feature.
+
+    Args:
+        vectors: Input vectors of shape (n_samples, n_features)
+
+    Returns:
+        Standardized vectors (mean=0, std=1 per feature column)
+
+    Note:
+        Zero-variance features are returned unchanged
+    """
+    mean = np.mean(vectors, axis=0, keepdims=True)
+    std = np.std(vectors, axis=0, keepdims=True)
+    std = np.where(std == 0, 1.0, std)
+    return (vectors - mean) / std
+
+
+def normalize(
+    vectors: NDArray[np.floating[Any]],
+    method: NormalizationType | None = None,
+) -> NDArray[np.floating[Any]]:
+    """Apply normalization to vectors.
+
+    Args:
+        vectors: Input vectors of shape (n_samples, n_features)
+        method: Normalization method: 'l1', 'l2', 'standardize', 'none', or None
+
+    Returns:
+        Normalized vectors
+
+    Raises:
+        ValueError: If normalization produces NaN/Inf or method is unknown
+
+    Example:
+        >>> vectors = np.array([[1, 2, 3], [4, 5, 6]])
+        >>> normalize(vectors, method='l2')
+        array([[0.267, 0.535, 0.802],
+               [0.456, 0.570, 0.684]])
+    """
+    if method is None or method == "none":
+        return vectors.copy()
+    elif method == "l1":
+        normalized = normalize_l1(vectors)
+    elif method == "l2":
+        normalized = normalize_l2(vectors)
+    elif method == "standardize":
+        normalized = standardize(vectors)
+    else:
+        raise ValueError(
+            f"Invalid normalization method: '{method}'. "
+            "Valid options: 'l1', 'l2', 'standardize', 'none', None"
+        )
+
+    # Validate output
+    if not np.all(np.isfinite(normalized)):
+        raise ValueError(
+            "Normalization produced NaN or Inf values. "
+            "This may indicate zero-variance features or numerical instability. "
+            f"Normalization method: {method}"
+        )
+
+    return normalized
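A quick usage sketch for normalize(). The import path is inferred from entry 118 in the file list above; the printed values follow directly from the function definitions in the hunk:

    import numpy as np

    from ml4t.diagnostic.evaluation.trade_shap.normalize import normalize

    shap_vectors = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    unit = normalize(shap_vectors, method="l2")
    print(np.linalg.norm(unit, axis=1))  # each row now has unit norm: [1. 1.]

    z = normalize(shap_vectors, method="standardize")
    print(z.mean(axis=0))  # ~[0. 0. 0.] per feature column
    print(z.std(axis=0))   # [1. 1. 1.] per feature column

    # Unknown methods fail fast:
    # normalize(shap_vectors, method="minmax")  -> raises ValueError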