ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1136 @@
1
+ """Trade-level analysis for backtest diagnostics and SHAP attribution.
2
+
3
+ This module provides tools for analyzing individual trades from backtests,
4
+ identifying worst/best performers, and computing trade-level statistics.
5
+
6
+ Core Components:
7
+ - TradeMetrics: Enriched trade data with computed metrics
8
+ - TradeAnalysis: Main analyzer for extracting worst/best trades
9
+ - TradeStatistics: Aggregate statistics across trades
10
+ - TradeAnalysisResult: Result schema with serialization
11
+
12
+ Integration with ml4t-diagnostics workflow:
13
+ 1. Load backtest results → Extract trades (TradeRecord instances)
14
+ 2. Analyze trades → Identify worst performers (TradeAnalysis)
15
+ 3. Compute statistics → Understand trade distribution (TradeStatistics)
16
+ 4. Feed to SHAP → Explain failures (trade_shap_diagnostics.py)
17
+
18
+ Example - Basic usage:
19
+ >>> from ml4t.diagnostic.integration import TradeRecord
20
+ >>> from ml4t.diagnostic.evaluation import TradeAnalysis
21
+ >>> from datetime import datetime, timedelta
22
+ >>>
23
+ >>> # Create trade records from backtest
24
+ >>> trades = [
25
+ ... TradeRecord(
26
+ ... timestamp=datetime(2024, 1, 15),
27
+ ... symbol="AAPL",
28
+ ... entry_price=150.0,
29
+ ... exit_price=155.0,
30
+ ... pnl=500.0,
31
+ ... duration=timedelta(days=5),
32
+ ... direction="long"
33
+ ... ),
34
+ ... # ... more trades
35
+ ... ]
36
+ >>>
37
+ >>> # Analyze trades
38
+ >>> analyzer = TradeAnalysis(trades)
39
+ >>> worst = analyzer.worst_trades(n=10)
40
+ >>> best = analyzer.best_trades(n=10)
41
+ >>> stats = analyzer.compute_statistics()
42
+ >>>
43
+ >>> print(f"Win rate: {stats.win_rate:.2%}")
44
+ >>> print(f"Average PnL: ${stats.avg_pnl:.2f}")
45
+
46
+ Example - Advanced usage with config:
47
+ >>> from ml4t.diagnostic.config import TradeConfig, ExtractionSettings, FilterSettings
48
+ >>>
49
+ >>> config = TradeConfig(
50
+ ... extraction=ExtractionSettings(n_worst=20, n_best=10),
51
+ ... filter=FilterSettings(
52
+ ... min_duration=timedelta(hours=1),
53
+ ... min_pnl=-1000.0
54
+ ... )
55
+ ... )
56
+ >>>
57
+ >>> analyzer = TradeAnalysis.from_config(trades, config)
58
+ >>> result = analyzer.analyze()
59
+ >>>
60
+ >>> # Export for storage
61
+ >>> result.to_json_string()
62
+ >>> result.get_dataframe("worst_trades")
63
+ >>> result.get_dataframe("statistics")
64
+
65
+ Example - Integration with SHAP diagnostics:
66
+ >>> from ml4t.diagnostic.evaluation import TradeShapAnalyzer
67
+ >>>
68
+ >>> # Get worst trades
69
+ >>> worst_trades = analyzer.worst_trades(n=20)
70
+ >>>
71
+ >>> # Explain with SHAP
72
+ >>> shap_analyzer = TradeShapAnalyzer(model, features, shap_values)
73
+ >>> patterns = shap_analyzer.explain_worst_trades(worst_trades)
74
+ >>>
75
+ >>> for pattern in patterns:
76
+ ... print(pattern.hypothesis)
77
+ ... print(pattern.actions)
78
+ """
79
+
80
+ from __future__ import annotations
81
+
82
+ import heapq
83
+ from datetime import UTC, datetime, timedelta
84
+ from typing import Any, Literal, SupportsFloat, cast
85
+
86
+ import polars as pl
87
+ from pydantic import BaseModel, Field, field_validator
88
+
89
+ from ml4t.diagnostic.integration.backtest_contract import TradeRecord
90
+
91
+
92
class TradeMetrics(BaseModel):
    """Enriched trade data with computed metrics for analysis.

    Extends TradeRecord with additional computed fields useful for
    trade analysis, ranking, and diagnostics. Provides methods for
    DataFrame conversion and serialization.

    This class wraps TradeRecord and adds:
    - Return percentage calculation
    - Duration in hours/days for easy filtering
    - Return per day (annualized-like metric)
    - Ranking helpers

    Required Fields (from TradeRecord):
        timestamp: Trade exit timestamp
        symbol: Asset symbol
        entry_price: Average entry price
        exit_price: Average exit price
        pnl: Realized profit/loss
        duration: Time between entry and exit
        direction: Trade direction (long/short)

    Computed Fields (properties):
        return_pct: Return as percentage of entry price
        duration_hours: Duration in hours
        duration_days: Duration in days
        pnl_per_day: PnL normalized by duration

    Example - Create from TradeRecord:
        >>> metrics = TradeMetrics.from_trade_record(trade_record)
        >>> print(f"Return: {metrics.return_pct:.2%}")
        >>> print(f"PnL per day: ${metrics.pnl_per_day:.2f}")

    Example - Convert to DataFrame:
        >>> trades = [TradeMetrics.from_trade_record(tr) for tr in trade_records]
        >>> df = TradeMetrics.to_dataframe(trades)
        >>> print(df.select(["symbol", "pnl", "return_pct"]))
    """

    # Core fields (from TradeRecord)
    timestamp: datetime = Field(..., description="Trade exit timestamp")
    symbol: str = Field(..., min_length=1, description="Asset symbol")
    entry_price: float = Field(..., gt=0.0, description="Average entry price")
    exit_price: float = Field(..., gt=0.0, description="Average exit price")
    pnl: float = Field(..., description="Realized profit/loss")
    duration: timedelta = Field(..., description="Time between entry and exit")
    direction: Literal["long", "short"] | None = Field(None, description="Trade direction")

    # Optional fields (from TradeRecord)
    quantity: float | None = Field(None, gt=0.0, description="Position size")
    entry_timestamp: datetime | None = Field(None, description="Position entry timestamp")
    fees: float | None = Field(None, ge=0.0, description="Total transaction fees")
    slippage: float | None = Field(None, ge=0.0, description="Slippage cost")
    metadata: dict[str, Any] | None = Field(None, description="Arbitrary metadata")
    regime_info: dict[str, str] | None = Field(None, description="Market regime info")

    @field_validator("duration")
    @classmethod
    def validate_duration_positive(cls, v: timedelta) -> timedelta:
        """Ensure duration is positive."""
        if v.total_seconds() <= 0:
            raise ValueError(f"Duration must be positive, got {v}")
        return v

    @property
    def return_pct(self) -> float:
        """Return as percentage of entry price.

        Formula:
            - Long: (exit_price - entry_price) / entry_price
            - Short: (entry_price - exit_price) / entry_price
            - Unknown: absolute price change / entry_price (unsigned)

        Returns:
            Return percentage (e.g., 0.05 = 5% return)

        Example:
            >>> metrics.return_pct  # 0.0333 = 3.33%
        """
        if self.direction == "long":
            return (self.exit_price - self.entry_price) / self.entry_price
        elif self.direction == "short":
            return (self.entry_price - self.exit_price) / self.entry_price
        else:
            # Unknown direction - use absolute price change (unsigned return)
            return abs(self.exit_price - self.entry_price) / self.entry_price

    @property
    def duration_hours(self) -> float:
        """Duration in hours.

        Returns:
            Duration as float hours (e.g., 120.5)
        """
        return self.duration.total_seconds() / 3600.0

    @property
    def duration_days(self) -> float:
        """Duration in days.

        Returns:
            Duration as float days (e.g., 5.02)
        """
        return self.duration.total_seconds() / 86400.0

    @property
    def pnl_per_day(self) -> float:
        """PnL normalized by duration in days.

        Provides a duration-adjusted performance metric. Useful for
        comparing trades of different holding periods.

        Returns:
            PnL per day (can be negative)

        Example:
            >>> metrics.pnl_per_day  # 100.0 (earned $100/day)
        """
        # Defensive: the duration validator rejects non-positive durations,
        # so this guard only protects against pathological float underflow.
        if self.duration_days == 0:
            return 0.0
        return self.pnl / self.duration_days

    @classmethod
    def from_trade_record(cls, trade: TradeRecord) -> TradeMetrics:
        """Create TradeMetrics from TradeRecord.

        Args:
            trade: TradeRecord instance from backtest

        Returns:
            TradeMetrics with computed fields

        Example:
            >>> metrics = TradeMetrics.from_trade_record(trade_record)
        """
        return cls(
            timestamp=trade.timestamp,
            symbol=trade.symbol,
            entry_price=trade.entry_price,
            exit_price=trade.exit_price,
            pnl=trade.pnl,
            duration=trade.duration,
            direction=trade.direction,
            quantity=trade.quantity,
            entry_timestamp=trade.entry_timestamp,
            fees=trade.fees,
            slippage=trade.slippage,
            metadata=trade.metadata,
            regime_info=trade.regime_info,
        )

    def to_dict(self) -> dict[str, Any]:
        """Export to dictionary format.

        Returns:
            Dictionary with all trade data including computed fields.
            The ``duration`` field is replaced by ``duration_seconds``
            for JSON compatibility.

        Example:
            >>> metrics.to_dict()
            {
                'timestamp': '2024-01-15T10:30:00',
                'symbol': 'AAPL',
                'pnl': 500.0,
                'return_pct': 0.0333,
                'duration_hours': 120.0,
                ...
            }
        """
        data = self.model_dump(mode="json")
        # Convert timedelta to total seconds for JSON compatibility
        if "duration" in data:
            data["duration_seconds"] = self.duration.total_seconds()
            del data["duration"]
        # Include computed properties
        data["return_pct"] = self.return_pct
        data["duration_hours"] = self.duration_hours
        data["duration_days"] = self.duration_days
        data["pnl_per_day"] = self.pnl_per_day
        return data

    @staticmethod
    def to_dataframe(trades: list[TradeMetrics]) -> pl.DataFrame:
        """Convert list of TradeMetrics to Polars DataFrame.

        Args:
            trades: List of TradeMetrics instances

        Returns:
            Polars DataFrame with all trade data and computed metrics.
            Column dtypes are identical whether ``trades`` is empty or not.

        Example:
            >>> df = TradeMetrics.to_dataframe(metrics_list)
            >>> print(df.select(["symbol", "pnl", "return_pct"]))
            >>> df.sort("pnl").head(10)  # Worst 10 trades
        """
        # Single source of truth for column names/dtypes, used on both the
        # empty and non-empty paths. Passing it explicitly below fixes a
        # dtype-instability bug: without it, Polars infers a Null dtype for
        # any optional column (direction, entry_timestamp, quantity, fees,
        # slippage) whose values are all None, so the resulting schema
        # would not match the empty-list schema.
        schema = {
            "timestamp": pl.Datetime,
            "symbol": pl.String,
            "entry_price": pl.Float64,
            "exit_price": pl.Float64,
            "pnl": pl.Float64,
            "duration_seconds": pl.Float64,
            "direction": pl.String,
            "quantity": pl.Float64,
            "entry_timestamp": pl.Datetime,
            "fees": pl.Float64,
            "slippage": pl.Float64,
            "return_pct": pl.Float64,
            "duration_hours": pl.Float64,
            "duration_days": pl.Float64,
            "pnl_per_day": pl.Float64,
        }

        if not trades:
            # Empty DataFrame with the expected schema
            return pl.DataFrame(schema=schema)

        # Convert to list of dicts (metadata/regime_info are intentionally
        # excluded - nested dicts do not map cleanly onto flat columns)
        data = []
        for trade in trades:
            trade_dict = {
                "timestamp": trade.timestamp,
                "symbol": trade.symbol,
                "entry_price": trade.entry_price,
                "exit_price": trade.exit_price,
                "pnl": trade.pnl,
                "duration_seconds": trade.duration.total_seconds(),
                "direction": trade.direction,
                "quantity": trade.quantity,
                "entry_timestamp": trade.entry_timestamp,
                "fees": trade.fees,
                "slippage": trade.slippage,
                "return_pct": trade.return_pct,
                "duration_hours": trade.duration_hours,
                "duration_days": trade.duration_days,
                "pnl_per_day": trade.pnl_per_day,
            }
            data.append(trade_dict)

        return pl.DataFrame(data, schema=schema)
348
+
349
+
350
class TradeFilters(BaseModel):
    """Typed filter configuration for trade analysis.

    Provides type-safe, validated filtering options instead of raw dict[str, Any].
    All fields are optional - only specified filters are applied.

    Fields:
        symbols: List of symbols to include (None = all symbols)
        min_duration: Minimum trade duration (None = no minimum)
        min_pnl: Minimum PnL to include (None = no minimum)
        max_pnl: Maximum PnL to include (None = no maximum)
        start_date: Start of date range (None = no start bound)
        end_date: End of date range (None = no end bound)

    Example:
        >>> filters = TradeFilters(
        ...     symbols=["AAPL", "MSFT"],
        ...     min_duration=timedelta(hours=1),
        ...     min_pnl=-1000.0
        ... )
        >>> analyzer = TradeAnalysis(trades, filters=filters)
    """

    symbols: list[str] | None = Field(None, description="Symbols to include")
    min_duration: timedelta | None = Field(None, description="Minimum trade duration")
    min_pnl: float | None = Field(None, description="Minimum PnL to include")
    max_pnl: float | None = Field(None, description="Maximum PnL to include")
    start_date: datetime | None = Field(None, description="Start of date range")
    end_date: datetime | None = Field(None, description="End of date range")

    def to_dict(self) -> dict[str, Any]:
        """Convert to legacy dict format for backward compatibility."""
        # min_duration is the only field that changes representation: it is
        # flattened to seconds under the legacy "min_duration_seconds" key.
        duration_seconds = (
            None if self.min_duration is None else self.min_duration.total_seconds()
        )
        candidates: list[tuple[str, Any]] = [
            ("symbols", self.symbols),
            ("min_duration_seconds", duration_seconds),
            ("min_pnl", self.min_pnl),
            ("max_pnl", self.max_pnl),
            ("start_date", self.start_date),
            ("end_date", self.end_date),
        ]
        # Unset (None) filters are omitted entirely from the legacy dict.
        return {key: value for key, value in candidates if value is not None}
396
+
397
+
398
+ class TradeStatistics(BaseModel):
399
+ """Aggregate statistics across multiple trades.
400
+
401
+ Computes summary statistics for trade analysis:
402
+ - Win/loss metrics (win rate, profit factor)
403
+ - PnL distribution (mean, std, quartiles, skewness)
404
+ - Duration distribution (mean, median, quartiles)
405
+ - Trade counts and breakdowns
406
+
407
+ Used by TradeAnalysisResult to provide high-level performance summary.
408
+
409
+ Fields:
410
+ n_trades: Total number of trades
411
+ n_winners: Number of profitable trades
412
+ n_losers: Number of losing trades
413
+ win_rate: Fraction of winning trades
414
+ total_pnl: Sum of all PnL
415
+ avg_pnl: Mean PnL per trade
416
+ pnl_std: Standard deviation of PnL
417
+ pnl_skewness: Skewness of PnL distribution
418
+ pnl_kurtosis: Kurtosis of PnL distribution
419
+ pnl_quartiles: 25th, 50th (median), 75th percentiles
420
+ avg_winner: Average PnL of winning trades
421
+ avg_loser: Average PnL of losing trades
422
+ profit_factor: Gross profit / gross loss
423
+ avg_duration_days: Average trade duration in days
424
+ median_duration_days: Median trade duration
425
+ duration_quartiles: Duration percentiles
426
+
427
+ Example:
428
+ >>> stats = TradeStatistics.compute(trades)
429
+ >>> print(f"Win rate: {stats.win_rate:.2%}")
430
+ >>> print(f"Avg PnL: ${stats.avg_pnl:.2f}")
431
+ >>> print(f"Profit factor: {stats.profit_factor:.2f}")
432
+ >>> print(stats.summary())
433
+ """
434
+
435
    # Trade counts
    n_trades: int = Field(..., ge=0, description="Total number of trades")
    n_winners: int = Field(..., ge=0, description="Number of profitable trades")
    n_losers: int = Field(..., ge=0, description="Number of losing trades")

    # Win rate and PnL metrics
    win_rate: float = Field(..., ge=0.0, le=1.0, description="Fraction of winning trades")
    total_pnl: float = Field(..., description="Sum of all PnL")
    avg_pnl: float = Field(..., description="Mean PnL per trade")
    pnl_std: float = Field(..., ge=0.0, description="Standard deviation of PnL")

    # Distribution metrics (compute() leaves these None when scipy is unavailable)
    pnl_skewness: float | None = Field(None, description="PnL distribution skewness")
    pnl_kurtosis: float | None = Field(None, description="PnL distribution kurtosis")
    pnl_quartiles: dict[str, float] = Field(..., description="PnL quartiles (q25, q50, q75)")

    # Winner/loser breakdown (None when the corresponding group is empty;
    # profit_factor is None unless both winners and losers exist)
    avg_winner: float | None = Field(None, description="Average PnL of winners")
    avg_loser: float | None = Field(None, description="Average PnL of losers")
    profit_factor: float | None = Field(None, description="Gross profit / gross loss")

    # Duration metrics
    avg_duration_days: float = Field(..., ge=0.0, description="Average duration in days")
    median_duration_days: float = Field(..., ge=0.0, description="Median duration in days")
    duration_quartiles: dict[str, float] = Field(
        ..., description="Duration quartiles (q25, q50, q75)"
    )
462
+
463
+ @staticmethod
464
+ def compute(trades: list[TradeMetrics]) -> TradeStatistics:
465
+ """Compute statistics from list of trades.
466
+
467
+ Args:
468
+ trades: List of TradeMetrics instances
469
+
470
+ Returns:
471
+ TradeStatistics with all computed metrics
472
+
473
+ Raises:
474
+ ValueError: If trades list is empty
475
+
476
+ Example:
477
+ >>> stats = TradeStatistics.compute(metrics_list)
478
+ """
479
+ if not trades:
480
+ raise ValueError("Cannot compute statistics for empty trade list")
481
+
482
+ # Convert to DataFrame for efficient computation
483
+ df = TradeMetrics.to_dataframe(trades)
484
+
485
+ # Count trades
486
+ n_trades = len(df)
487
+ n_winners = int(df.filter(pl.col("pnl") > 0).height)
488
+ n_losers = int(df.filter(pl.col("pnl") < 0).height)
489
+ win_rate = n_winners / n_trades if n_trades > 0 else 0.0
490
+
491
+ # PnL metrics
492
+ pnl_series = df["pnl"]
493
+ total_pnl = float(cast(SupportsFloat, pnl_series.sum()))
494
+ avg_pnl = float(cast(SupportsFloat, pnl_series.mean()))
495
+ pnl_std_value = pnl_series.std()
496
+ pnl_std = float(cast(SupportsFloat, pnl_std_value)) if pnl_std_value is not None else 0.0
497
+
498
+ # Distribution metrics (requires scipy for skewness/kurtosis)
499
+ try:
500
+ from scipy import stats as scipy_stats
501
+
502
+ pnl_values = pnl_series.to_numpy()
503
+ pnl_skewness = float(scipy_stats.skew(pnl_values))
504
+ pnl_kurtosis = float(scipy_stats.kurtosis(pnl_values))
505
+ except ImportError:
506
+ pnl_skewness = None
507
+ pnl_kurtosis = None
508
+
509
+ # Quartiles
510
+ pnl_q25 = float(cast(SupportsFloat, pnl_series.quantile(0.25)))
511
+ pnl_q50 = float(cast(SupportsFloat, pnl_series.quantile(0.50)))
512
+ pnl_q75 = float(cast(SupportsFloat, pnl_series.quantile(0.75)))
513
+ pnl_quartiles = {"q25": pnl_q25, "q50": pnl_q50, "q75": pnl_q75}
514
+
515
+ # Winner/loser breakdown
516
+ winners = df.filter(pl.col("pnl") > 0)
517
+ losers = df.filter(pl.col("pnl") < 0)
518
+
519
+ avg_winner = (
520
+ float(cast(SupportsFloat, winners["pnl"].mean())) if winners.height > 0 else None
521
+ )
522
+ avg_loser = float(cast(SupportsFloat, losers["pnl"].mean())) if losers.height > 0 else None
523
+
524
+ # Profit factor (only defined if both winners and losers exist)
525
+ gross_profit = float(winners["pnl"].sum()) if winners.height > 0 else 0.0
526
+ gross_loss = abs(float(losers["pnl"].sum())) if losers.height > 0 else 0.0
527
+ if winners.height > 0 and losers.height > 0:
528
+ profit_factor = gross_profit / gross_loss
529
+ else:
530
+ profit_factor = None # Undefined when all winners or all losers
531
+
532
+ # Duration metrics
533
+ duration_series = df["duration_days"]
534
+ avg_duration_days = float(cast(SupportsFloat, duration_series.mean()))
535
+ median_duration_days = float(cast(SupportsFloat, duration_series.median()))
536
+ dur_q25 = float(cast(SupportsFloat, duration_series.quantile(0.25)))
537
+ dur_q50 = float(cast(SupportsFloat, duration_series.quantile(0.50)))
538
+ dur_q75 = float(cast(SupportsFloat, duration_series.quantile(0.75)))
539
+ duration_quartiles = {"q25": dur_q25, "q50": dur_q50, "q75": dur_q75}
540
+
541
+ return TradeStatistics(
542
+ n_trades=n_trades,
543
+ n_winners=n_winners,
544
+ n_losers=n_losers,
545
+ win_rate=win_rate,
546
+ total_pnl=total_pnl,
547
+ avg_pnl=avg_pnl,
548
+ pnl_std=pnl_std,
549
+ pnl_skewness=pnl_skewness,
550
+ pnl_kurtosis=pnl_kurtosis,
551
+ pnl_quartiles=pnl_quartiles,
552
+ avg_winner=avg_winner,
553
+ avg_loser=avg_loser,
554
+ profit_factor=profit_factor,
555
+ avg_duration_days=avg_duration_days,
556
+ median_duration_days=median_duration_days,
557
+ duration_quartiles=duration_quartiles,
558
+ )
559
+
560
+ def summary(self) -> str:
561
+ """Generate human-readable summary of statistics.
562
+
563
+ Returns:
564
+ Formatted summary string
565
+
566
+ Example:
567
+ >>> print(stats.summary())
568
+ Trade Statistics
569
+ ================
570
+ Total trades: 150
571
+ Win rate: 62.67%
572
+ ...
573
+ """
574
+ lines = ["Trade Statistics", "=" * 50]
575
+
576
+ # Trade counts
577
+ lines.append(f"Total trades: {self.n_trades}")
578
+ lines.append(f"Winners: {self.n_winners} | Losers: {self.n_losers}")
579
+ lines.append(f"Win rate: {self.win_rate:.2%}")
580
+ lines.append("")
581
+
582
+ # PnL summary
583
+ lines.append("PnL Metrics")
584
+ lines.append("-" * 50)
585
+ lines.append(f"Total PnL: ${self.total_pnl:,.2f}")
586
+ lines.append(f"Average PnL: ${self.avg_pnl:,.2f} ± ${self.pnl_std:,.2f}")
587
+ if self.avg_winner is not None:
588
+ lines.append(f"Avg winner: ${self.avg_winner:,.2f}")
589
+ if self.avg_loser is not None:
590
+ lines.append(f"Avg loser: ${self.avg_loser:,.2f}")
591
+ if self.profit_factor is not None:
592
+ lines.append(f"Profit factor: {self.profit_factor:.2f}")
593
+ lines.append("")
594
+
595
+ # Distribution
596
+ lines.append("PnL Distribution")
597
+ lines.append("-" * 50)
598
+ lines.append(
599
+ f"Q25: ${self.pnl_quartiles['q25']:,.2f} | "
600
+ f"Median: ${self.pnl_quartiles['q50']:,.2f} | "
601
+ f"Q75: ${self.pnl_quartiles['q75']:,.2f}"
602
+ )
603
+ if self.pnl_skewness is not None:
604
+ lines.append(f"Skewness: {self.pnl_skewness:.3f}")
605
+ if self.pnl_kurtosis is not None:
606
+ lines.append(f"Kurtosis: {self.pnl_kurtosis:.3f}")
607
+ lines.append("")
608
+
609
+ # Duration
610
+ lines.append("Duration Metrics")
611
+ lines.append("-" * 50)
612
+ lines.append(f"Average: {self.avg_duration_days:.2f} days")
613
+ lines.append(f"Median: {self.median_duration_days:.2f} days")
614
+ lines.append(
615
+ f"Q25: {self.duration_quartiles['q25']:.2f} | "
616
+ f"Q50: {self.duration_quartiles['q50']:.2f} | "
617
+ f"Q75: {self.duration_quartiles['q75']:.2f}"
618
+ )
619
+
620
+ return "\n".join(lines)
621
+
622
+ def to_dataframe(self) -> pl.DataFrame:
623
+ """Convert statistics to DataFrame.
624
+
625
+ Returns:
626
+ Single-row DataFrame with all statistics
627
+
628
+ Example:
629
+ >>> df = stats.to_dataframe()
630
+ """
631
+ return pl.DataFrame(
632
+ [
633
+ {
634
+ "n_trades": self.n_trades,
635
+ "n_winners": self.n_winners,
636
+ "n_losers": self.n_losers,
637
+ "win_rate": self.win_rate,
638
+ "total_pnl": self.total_pnl,
639
+ "avg_pnl": self.avg_pnl,
640
+ "pnl_std": self.pnl_std,
641
+ "pnl_skewness": self.pnl_skewness,
642
+ "pnl_kurtosis": self.pnl_kurtosis,
643
+ "pnl_q25": self.pnl_quartiles["q25"],
644
+ "pnl_q50": self.pnl_quartiles["q50"],
645
+ "pnl_q75": self.pnl_quartiles["q75"],
646
+ "avg_winner": self.avg_winner,
647
+ "avg_loser": self.avg_loser,
648
+ "profit_factor": self.profit_factor,
649
+ "avg_duration_days": self.avg_duration_days,
650
+ "median_duration_days": self.median_duration_days,
651
+ "dur_q25": self.duration_quartiles["q25"],
652
+ "dur_q50": self.duration_quartiles["q50"],
653
+ "dur_q75": self.duration_quartiles["q75"],
654
+ }
655
+ ]
656
+ )
657
+
658
+
659
class TradeAnalysis:
    """Main analyzer for extracting worst/best trades and computing statistics.

    Provides high-level API for trade analysis workflows:
    1. Load trades (TradeRecord instances from backtest)
    2. Extract worst performers → Feed to SHAP diagnostics
    3. Extract best performers → Understand success patterns
    4. Compute statistics → Aggregate performance metrics

    The analyzer supports filtering by:
    - Symbol (e.g., only analyze AAPL, MSFT)
    - Duration (e.g., trades lasting > 1 hour)
    - PnL range (e.g., exclude small trades)
    - Date range (e.g., trades in Q4 2024)

    Example - Basic usage:
        >>> analyzer = TradeAnalysis(trade_records)
        >>> worst = analyzer.worst_trades(n=20)
        >>> best = analyzer.best_trades(n=10)
        >>> stats = analyzer.compute_statistics()
        >>> print(stats.summary())

    Example - With filtering:
        >>> from ml4t.diagnostic.evaluation.trade_analysis import TradeFilters
        >>>
        >>> filters = TradeFilters(
        ...     symbols=["AAPL", "MSFT"],
        ...     min_duration=timedelta(hours=1),
        ...     min_pnl=-1000.0,
        ...     start_date=datetime(2024, 10, 1)
        ... )
        >>>
        >>> analyzer = TradeAnalysis(trade_records, filters=filters)
        >>> result = analyzer.analyze(n_worst=20, n_best=10)

    Example - Integration with config:
        >>> from ml4t.diagnostic.config import TradeConfig, ExtractionSettings
        >>>
        >>> config = TradeConfig(
        ...     extraction=ExtractionSettings(n_worst=20, n_best=10),
        ... )
        >>>
        >>> analyzer = TradeAnalysis.from_config(trade_records, config)
        >>> result = analyzer.analyze()
        >>> result.to_json_string()
    """

    def __init__(
        self,
        trades: list[TradeRecord],
        filter_config: dict[str, Any] | None = None,
        *,
        filters: TradeFilters | None = None,
    ):
        """Initialize analyzer with trades.

        Args:
            trades: List of TradeRecord instances from backtest
            filter_config: Optional filtering configuration (legacy dict format)
            filters: Optional typed TradeFilters (preferred over filter_config)

        Raises:
            ValueError: If trades is empty, or no trades survive filtering

        Example:
            >>> # Using typed filters (preferred)
            >>> filters = TradeFilters(symbols=["AAPL"], min_pnl=-1000)
            >>> analyzer = TradeAnalysis(trades, filters=filters)
            >>>
            >>> # Using legacy dict format
            >>> analyzer = TradeAnalysis(trades, filter_config={"symbols": ["AAPL"]})
        """
        if not trades:
            raise ValueError("Cannot analyze empty trade list")

        # Convert to TradeMetrics (per-trade derived fields: pnl, duration, ...)
        self.trades = [TradeMetrics.from_trade_record(t) for t in trades]

        # Normalize filters to dict format (typed TradeFilters takes precedence)
        if filters is not None:
            filter_config = filters.to_dict()

        # Apply filters if provided
        if filter_config:
            self.trades = self._apply_filters(self.trades, filter_config)

        if not self.trades:
            raise ValueError("No trades remaining after applying filters")

    @classmethod
    def from_config(
        cls,
        trades: list[TradeRecord],
        config: Any,  # TradeConfig - avoid circular import
    ) -> TradeAnalysis:
        """Create analyzer from configuration.

        Args:
            trades: List of TradeRecord instances
            config: TradeConfig instance

        Returns:
            TradeAnalysis instance

        Example:
            >>> from ml4t.diagnostic.config import TradeConfig, ExtractionSettings
            >>> config = TradeConfig(extraction=ExtractionSettings(n_worst=20, n_best=10))
            >>> analyzer = TradeAnalysis.from_config(trades, config)
        """
        # Extract filter config if present.
        filter_config = getattr(config, "filters", None)
        # BUG FIX: config.filters may be a typed TradeFilters object rather
        # than the legacy dict; _apply_filters expects dict semantics
        # ("symbols" in ..., .get(...)), so normalize via to_dict() when
        # available (TradeFilters provides to_dict — see __init__).
        if filter_config is not None and not isinstance(filter_config, dict):
            to_dict = getattr(filter_config, "to_dict", None)
            if callable(to_dict):
                filter_config = to_dict()
        return cls(trades, filter_config=filter_config)

    @staticmethod
    def _apply_filters(
        trades: list[TradeMetrics],
        filters: dict[str, Any],
    ) -> list[TradeMetrics]:
        """Apply filters to trade list in a single pass.

        Args:
            trades: List of TradeMetrics
            filters: Filter criteria (legacy dict format; all keys optional)

        Returns:
            Filtered trade list (order preserved)
        """
        # Pre-extract filter values to avoid repeated dict lookups
        symbols: set[str] | None = None
        if "symbols" in filters and filters["symbols"]:
            symbols = set(filters["symbols"])

        min_dur: float | None = filters.get("min_duration_seconds")
        min_pnl: float | None = filters.get("min_pnl")
        max_pnl: float | None = filters.get("max_pnl")
        start_date: datetime | None = filters.get("start_date")
        end_date: datetime | None = filters.get("end_date")

        # Single-pass filtering: a trade survives only if it passes every
        # configured predicate.
        def matches(t: TradeMetrics) -> bool:
            if symbols is not None and t.symbol not in symbols:
                return False
            if min_dur is not None and t.duration.total_seconds() < min_dur:
                return False
            if min_pnl is not None and t.pnl < min_pnl:
                return False
            if max_pnl is not None and t.pnl > max_pnl:
                return False
            if start_date is not None and t.timestamp < start_date:
                return False
            if end_date is not None and t.timestamp > end_date:
                return False
            return True

        return [t for t in trades if matches(t)]

    def worst_trades(self, n: int = 10) -> list[TradeMetrics]:
        """Extract N worst trades by PnL.

        Uses heapq.nsmallest for O(n + k log n) efficiency when k << n.

        Args:
            n: Number of worst trades to extract

        Returns:
            List of worst N trades, sorted by PnL (ascending)

        Raises:
            ValueError: If n is not positive

        Example:
            >>> worst = analyzer.worst_trades(n=20)
            >>> for trade in worst[:5]:
            ...     print(f"{trade.symbol}: ${trade.pnl:.2f}")

        See Also
        --------
        best_trades : Extract best performing trades
        compute_statistics : Aggregate performance metrics
        analyze : Complete analysis with worst/best trades
        """
        if n <= 0:
            raise ValueError(f"n must be positive, got {n}")

        # heapq avoids a full O(n log n) sort when only k trades are needed.
        return heapq.nsmallest(n, self.trades, key=lambda t: t.pnl)

    def best_trades(self, n: int = 10) -> list[TradeMetrics]:
        """Extract N best trades by PnL.

        Uses heapq.nlargest for O(n + k log n) efficiency when k << n.

        Args:
            n: Number of best trades to extract

        Returns:
            List of best N trades, sorted by PnL (descending)

        Raises:
            ValueError: If n is not positive

        Example:
            >>> best = analyzer.best_trades(n=10)
            >>> for trade in best[:5]:
            ...     print(f"{trade.symbol}: ${trade.pnl:.2f}")

        See Also
        --------
        worst_trades : Extract worst performing trades
        compute_statistics : Aggregate performance metrics
        analyze : Complete analysis with worst/best trades
        """
        if n <= 0:
            raise ValueError(f"n must be positive, got {n}")

        # heapq avoids a full O(n log n) sort when only k trades are needed.
        return heapq.nlargest(n, self.trades, key=lambda t: t.pnl)

    def compute_statistics(self) -> TradeStatistics:
        """Compute aggregate statistics across all (post-filter) trades.

        Returns:
            TradeStatistics with summary metrics

        Example:
            >>> stats = analyzer.compute_statistics()
            >>> print(f"Win rate: {stats.win_rate:.2%}")

        See Also
        --------
        TradeStatistics : Statistics result schema
        TradeStatistics.compute : Static method for statistics computation
        analyze : Complete analysis including statistics
        """
        return TradeStatistics.compute(self.trades)

    def analyze(
        self,
        n_worst: int = 10,
        n_best: int = 10,
    ) -> TradeAnalysisResult:
        """Run complete analysis and return result object.

        Args:
            n_worst: Number of worst trades to extract
            n_best: Number of best trades to extract

        Returns:
            TradeAnalysisResult with worst/best trades, statistics,
            and total trade count

        Example:
            >>> result = analyzer.analyze(n_worst=20, n_best=10)
            >>> print(result.summary())
            >>> result.to_json_string()

        See Also
        --------
        worst_trades : Extract worst trades
        best_trades : Extract best trades
        compute_statistics : Compute aggregate statistics
        TradeAnalysisResult : Result schema with serialization
        """
        return TradeAnalysisResult(
            worst_trades=self.worst_trades(n_worst),
            best_trades=self.best_trades(n_best),
            statistics=self.compute_statistics(),
            n_total_trades=len(self.trades),
        )
918
+
919
+
920
class TradeAnalysisResult(BaseModel):
    """Result schema for trade analysis with serialization support.

    Contains the complete output of a trade analysis:
    - Worst N trades (for SHAP diagnostics)
    - Best N trades (for success pattern analysis)
    - Aggregate statistics across all trades
    - Metadata (total trades analyzed)

    This schema extends BaseResult to provide:
    - JSON serialization via to_json_string()
    - DataFrame export via get_dataframe()
    - Human-readable summary via summary()

    Use this to store and retrieve analysis results, or to pass
    data between different stages of the diagnostics workflow.

    Fields:
        worst_trades: List of worst N trades by PnL
        best_trades: List of best N trades by PnL
        statistics: Aggregate statistics
        n_total_trades: Total trades analyzed (before worst/best filtering)
        analysis_type: Type of analysis ("trade_analysis")
        created_at: ISO timestamp of analysis creation

    Example - Basic usage:
        >>> result = analyzer.analyze(n_worst=20, n_best=10)
        >>> print(result.summary())
        >>> result.to_json_string()
        >>> df = result.get_dataframe("worst_trades")

    Example - Serialization:
        >>> # Save to file
        >>> with open("analysis_result.json", "w") as f:
        ...     f.write(result.to_json_string())
        >>>
        >>> # Load from file
        >>> with open("analysis_result.json") as f:
        ...     data = json.load(f)
        >>> result = TradeAnalysisResult(**data)

    Example - DataFrame export:
        >>> # Get worst trades as DataFrame
        >>> df_worst = result.get_dataframe("worst_trades")
        >>>
        >>> # Get statistics as DataFrame
        >>> df_stats = result.get_dataframe("statistics")
        >>>
        >>> # Get all available DataFrames
        >>> available = result.list_available_dataframes()
    """

    # Result fields
    worst_trades: list[TradeMetrics] = Field(
        ...,
        description="List of worst N trades by PnL",
    )
    best_trades: list[TradeMetrics] = Field(
        ...,
        description="List of best N trades by PnL",
    )
    statistics: TradeStatistics = Field(
        ...,
        description="Aggregate statistics across all trades",
    )
    n_total_trades: int = Field(
        ...,
        ge=1,
        description="Total number of trades analyzed",
    )

    # Metadata fields
    analysis_type: str = Field(
        default="trade_analysis",
        description="Type of analysis performed",
    )
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(UTC),
        description="Analysis creation timestamp (UTC)",
    )

    # BUG FIX: annotation was `int` although the docstring promises
    # "None for compact" — widened to `int | None` (backward compatible;
    # pydantic's model_dump_json accepts indent=None for compact output).
    def to_json_string(self, *, indent: int | None = 2) -> str:
        """Export to JSON string.

        Args:
            indent: Indentation level (None for compact)

        Returns:
            JSON string representation

        Example:
            >>> json_str = result.to_json_string()
            >>> with open("result.json", "w") as f:
            ...     f.write(json_str)
        """
        return self.model_dump_json(indent=indent)

    def to_dict(self) -> dict[str, Any]:
        """Export to Python dictionary.

        Returns:
            Dictionary representation (Python-native values, not JSON-safe)

        Example:
            >>> data = result.to_dict()
            >>> data["statistics"]["win_rate"]
        """
        return self.model_dump(mode="python")

    def get_dataframe(self, name: str = "worst_trades") -> pl.DataFrame:
        """Get results as Polars DataFrame.

        Available DataFrames:
        - "worst_trades": Worst N trades with all fields
        - "best_trades": Best N trades with all fields
        - "statistics": Aggregate statistics (single row)
        - "all_trades": Combined worst + best trades

        Args:
            name: DataFrame name to retrieve

        Returns:
            Polars DataFrame with requested data

        Raises:
            ValueError: If DataFrame name not available

        Example:
            >>> df_worst = result.get_dataframe("worst_trades")
            >>> df_stats = result.get_dataframe("statistics")
        """
        if name == "worst_trades":
            return TradeMetrics.to_dataframe(self.worst_trades)
        elif name == "best_trades":
            return TradeMetrics.to_dataframe(self.best_trades)
        elif name == "statistics":
            return self.statistics.to_dataframe()
        elif name == "all_trades":
            # Combine worst and best (note: not the full original trade set)
            all_trades = self.worst_trades + self.best_trades
            return TradeMetrics.to_dataframe(all_trades)
        else:
            available = self.list_available_dataframes()
            raise ValueError(f"DataFrame '{name}' not available. Available: {', '.join(available)}")

    def list_available_dataframes(self) -> list[str]:
        """List available DataFrame views.

        Returns:
            List of available DataFrame names

        Example:
            >>> result.list_available_dataframes()
            ['worst_trades', 'best_trades', 'statistics', 'all_trades']
        """
        return ["worst_trades", "best_trades", "statistics", "all_trades"]

    def summary(self) -> str:
        """Generate human-readable summary of analysis.

        Returns:
            Formatted summary string with overview, statistics, and
            top-5 previews of the worst and best trades

        Example:
            >>> print(result.summary())
            Trade Analysis Summary
            ======================
            ...
        """
        lines = ["Trade Analysis Summary", "=" * 60]

        # Overview
        lines.append(f"Analysis timestamp: {self.created_at.isoformat()}")
        lines.append(f"Total trades analyzed: {self.n_total_trades}")
        lines.append(f"Worst trades extracted: {len(self.worst_trades)}")
        lines.append(f"Best trades extracted: {len(self.best_trades)}")
        lines.append("")

        # Statistics summary
        lines.append("Overall Statistics")
        lines.append("-" * 60)
        stats = self.statistics
        lines.append(f"Win rate: {stats.win_rate:.2%}")
        lines.append(f"Total PnL: ${stats.total_pnl:,.2f}")
        lines.append(f"Average PnL: ${stats.avg_pnl:,.2f} ± ${stats.pnl_std:,.2f}")
        if stats.profit_factor is not None:
            lines.append(f"Profit factor: {stats.profit_factor:.2f}")
        lines.append(f"Average duration: {stats.avg_duration_days:.2f} days")
        lines.append("")

        # Worst trades preview
        lines.append("Worst Trades (Top 5)")
        lines.append("-" * 60)
        for i, trade in enumerate(self.worst_trades[:5], 1):
            lines.append(
                f"{i}. {trade.symbol}: ${trade.pnl:,.2f} ({trade.return_pct:+.2%}) [{trade.duration_days:.1f}d]"
            )
        lines.append("")

        # Best trades preview
        lines.append("Best Trades (Top 5)")
        lines.append("-" * 60)
        for i, trade in enumerate(self.best_trades[:5], 1):
            lines.append(
                f"{i}. {trade.symbol}: ${trade.pnl:,.2f} ({trade.return_pct:+.2%}) [{trade.duration_days:.1f}d]"
            )

        return "\n".join(lines)

    def __repr__(self) -> str:
        """Concise representation."""
        return (
            f"TradeAnalysisResult("
            f"n_worst={len(self.worst_trades)}, "
            f"n_best={len(self.best_trades)}, "
            f"n_total={self.n_total_trades})"
        )