ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1172 @@
1
+ """Trade-level visualizations for backtest analysis.
2
+
3
+ Provides interactive Plotly plots for deep trade analysis:
4
+ - MFE/MAE scatter plot with exit efficiency
5
+ - Exit reason breakdown (sunburst/treemap)
6
+ - Trade PnL waterfall
7
+ - Duration distribution
8
+ - Size vs return analysis
9
+ - Consecutive wins/losses
10
+
11
+ These visualizations exceed QuantStats by providing trade-level insights
12
+ rather than just portfolio-level aggregates.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import TYPE_CHECKING, Literal
18
+
19
+ import numpy as np
20
+ import plotly.graph_objects as go
21
+ from plotly.subplots import make_subplots
22
+
23
+ from ml4t.diagnostic.visualization.core import (
24
+ create_base_figure,
25
+ get_color_scheme,
26
+ get_theme_config,
27
+ validate_theme,
28
+ )
29
+
30
+ if TYPE_CHECKING:
31
+ import polars as pl
32
+
33
+
34
+ # =============================================================================
35
+ # MFE/MAE Analysis
36
+ # =============================================================================
37
+
38
+
39
def _scale_marker_sizes(magnitudes):
    """Min-max scale non-negative magnitudes into marker sizes starting at 5.

    An empty input is returned as an empty float array instead of raising,
    because ``min``/``max`` are undefined on zero-size arrays.
    """
    if magnitudes.size == 0:
        return magnitudes.astype(float)
    lowest = magnitudes.min()
    # The 1e-10 term keeps the division defined when every magnitude is equal.
    return 5 + 20 * (magnitudes - lowest) / (magnitudes.max() - lowest + 1e-10)


def plot_mfe_mae_scatter(
    trades_df: pl.DataFrame,
    *,
    color_by: Literal["pnl", "pnl_pct", "duration", "exit_reason", "direction"] = "pnl",
    size_by: Literal["quantity", "notional", "uniform"] = "uniform",
    show_efficiency_frontier: bool = True,
    show_edge_ratio: bool = True,
    show_quadrants: bool = True,
    mfe_col: str = "mfe",
    mae_col: str = "mae",
    theme: str | None = None,
    height: int = 600,
    width: int | None = None,
) -> go.Figure:
    """Create MFE vs MAE scatter plot with exit efficiency analysis.

    Maximum Favorable Excursion (MFE) shows the best unrealized return
    during each trade. Maximum Adverse Excursion (MAE) shows the worst.
    This plot reveals exit timing efficiency. MAE is plotted on the x-axis
    as an absolute value, MFE on the y-axis.

    Parameters
    ----------
    trades_df : pl.DataFrame
        Trade data with mfe and mae columns. ``pnl`` is optional; when it is
        missing, PnL is treated as zero for every trade.
    color_by : str, default "pnl"
        Field to use for color encoding. ``"duration"`` reads the
        ``bars_held`` column. If the column required by the chosen option is
        missing, the plot falls back to coloring by PnL.
    size_by : str, default "uniform"
        Field to use for marker size. ``"notional"`` uses
        ``|entry_price * quantity|``. Falls back to a uniform size of 10 when
        the required columns are missing.
    show_efficiency_frontier : bool, default True
        Show diagonal line where exit equals MFE (perfect efficiency)
    show_edge_ratio : bool, default True
        Show aggregate edge ratio annotation
    show_quadrants : bool, default True
        Show quadrant labels (Q1: winners, Q2-4: losers by type)
    mfe_col : str, default "mfe"
        Column name for MFE
    mae_col : str, default "mae"
        Column name for MAE
    theme : str, optional
        Plot theme
    height : int, default 600
        Figure height
    width : int, optional
        Figure width

    Returns
    -------
    go.Figure
        Interactive scatter plot. For an empty ``trades_df`` the figure is
        returned without the frontier and quadrant decorations instead of
        raising.

    Examples
    --------
    >>> fig = plot_mfe_mae_scatter(trades_df, color_by="exit_reason")
    >>> fig.show()

    Notes
    -----
    Quadrant Interpretation:
    - Q1 (MFE > |MAE|, PnL > 0): Healthy winners with controlled drawdown
    - Q2 (MFE < |MAE|, PnL > 0): Lucky winners that recovered from large drawdown
    - Q3 (MFE < |MAE|, PnL < 0): Losers with insufficient profit opportunity
    - Q4 (MFE > |MAE|, PnL < 0): Poor exit timing - had profit but lost it
    """
    theme = validate_theme(theme)
    columns = trades_df.columns

    # Extract data
    mfe = trades_df[mfe_col].to_numpy()
    mae = np.abs(trades_df[mae_col].to_numpy())  # MAE as positive values
    pnl = trades_df["pnl"].to_numpy() if "pnl" in columns else np.zeros(len(mfe))
    has_trades = mfe.size > 0

    # Color encoding: categorical options produce one trace per category,
    # everything else uses a continuous color scale (PnL by default).
    by_exit_reason = color_by == "exit_reason" and "exit_reason" in columns
    by_direction = color_by == "direction" and "direction" in columns
    if color_by == "pnl_pct" and "pnl_pct" in columns:
        color_values = trades_df["pnl_pct"].to_numpy()
        colorscale = "RdYlGn"
        color_label = "Return (%)"
    elif color_by == "duration" and "bars_held" in columns:
        color_values = trades_df["bars_held"].to_numpy()
        colorscale = "Viridis"
        color_label = "Bars Held"
    else:
        color_values = pnl
        colorscale = "RdYlGn"
        color_label = "PnL ($)"

    # Size encoding
    if size_by == "quantity" and "quantity" in columns:
        sizes = _scale_marker_sizes(np.abs(trades_df["quantity"].to_numpy()))
    elif size_by == "notional" and "entry_price" in columns and "quantity" in columns:
        notional = np.abs(trades_df["entry_price"].to_numpy() * trades_df["quantity"].to_numpy())
        sizes = _scale_marker_sizes(notional)
    else:
        sizes = 10  # Uniform size

    # Create figure
    fig = create_base_figure(
        title="MFE vs MAE Analysis (Exit Efficiency)",
        xaxis_title="MAE (Max Adverse Excursion) - % Loss from Entry",
        yaxis_title="MFE (Max Favorable Excursion) - % Gain from Entry",
        height=height,
        width=width,
        theme=theme,
    )

    # Hover template
    hover_template = (
        "<b>Trade</b><br>"
        "MFE: %{y:.2%}<br>"
        "MAE: %{x:.2%}<br>"
        "PnL: $%{customdata[0]:.2f}<br>"
        "Return: %{customdata[1]:.2%}<br>"
        "<extra></extra>"
    )

    # Custom data for hover
    # NOTE(review): the pnl / 10000 fallback looks like an assumed notional; confirm.
    custom_data = np.column_stack(
        [
            pnl,
            trades_df["pnl_pct"].to_numpy() / 100 if "pnl_pct" in columns else pnl / 10000,
        ]
    )

    def marker_sizes(mask):
        """Return the marker sizes for the trades selected by ``mask``."""
        return sizes if isinstance(sizes, int) else sizes[mask]

    # Add scatter trace
    if by_exit_reason:
        exit_reasons = trades_df["exit_reason"].to_list()
        # dict.fromkeys keeps first-seen order, so legend order and the
        # reason -> colour mapping are stable across runs (set() is not).
        unique_reasons = list(dict.fromkeys(exit_reasons))
        colors = get_color_scheme("set2")

        for i, reason in enumerate(unique_reasons):
            mask = np.array([r == reason for r in exit_reasons], dtype=bool)
            fig.add_trace(
                go.Scatter(
                    x=mae[mask],
                    y=mfe[mask],
                    mode="markers",
                    name=reason,
                    marker={
                        "size": marker_sizes(mask),
                        "color": colors[i % len(colors)],
                        "opacity": 0.7,
                        "line": {"width": 1, "color": "white"},
                    },
                    customdata=custom_data[mask],
                    hovertemplate=hover_template.replace(
                        "<extra></extra>", f"Exit: {reason}<extra></extra>"
                    ),
                )
            )
    elif by_direction:
        directions = trades_df["direction"].to_list()
        # Long vs Short
        for direction in ["long", "short"]:
            mask = np.array([d == direction for d in directions], dtype=bool)
            color = "#28A745" if direction == "long" else "#DC3545"
            fig.add_trace(
                go.Scatter(
                    x=mae[mask],
                    y=mfe[mask],
                    mode="markers",
                    name=direction.title(),
                    marker={
                        "size": marker_sizes(mask),
                        "color": color,
                        "opacity": 0.7,
                        "line": {"width": 1, "color": "white"},
                    },
                    customdata=custom_data[mask],
                    hovertemplate=hover_template,
                )
            )
    else:
        # Continuous color scale
        fig.add_trace(
            go.Scatter(
                x=mae,
                y=mfe,
                mode="markers",
                marker={
                    "size": sizes,
                    "color": color_values,
                    "colorscale": colorscale,
                    "colorbar": {"title": color_label, "thickness": 15},
                    "opacity": 0.7,
                    "line": {"width": 1, "color": "white"},
                },
                customdata=custom_data,
                hovertemplate=hover_template,
                showlegend=False,
            )
        )

    # The frontier and quadrant labels are positioned from the data extremes,
    # which do not exist for an empty trade table.
    if has_trades:
        mfe_max = mfe.max()
        mae_max = mae.max()

        # Add efficiency frontier (diagonal)
        if show_efficiency_frontier:
            max_val = max(mfe_max, mae_max) * 1.1
            fig.add_trace(
                go.Scatter(
                    x=[0, max_val],
                    y=[0, max_val],
                    mode="lines",
                    name="Perfect Efficiency (Exit at MFE)",
                    line={"color": "gray", "dash": "dash", "width": 2},
                    hoverinfo="skip",
                )
            )

        # Add quadrant annotations
        if show_quadrants:
            annotations = [
                (0.8, 0.9, "Q1: Healthy Winners", "#28A745"),
                (0.2, 0.9, "Q2: Lucky Recovery", "#FFC107"),
                (0.2, 0.1, "Q3: No Opportunity", "#DC3545"),
                (0.8, 0.1, "Q4: Poor Exit", "#DC3545"),
            ]
            for x_frac, y_frac, text, color in annotations:
                fig.add_annotation(
                    x=mae_max * x_frac,
                    y=mfe_max * y_frac,
                    text=text,
                    showarrow=False,
                    font={"size": 10, "color": color},
                    opacity=0.7,
                )

    # Add edge ratio annotation
    if show_edge_ratio:
        mean_mae = np.mean(mae) if has_trades else 0.0
        edge_ratio = np.mean(mfe) / mean_mae if mean_mae > 0 else np.inf
        # Only winners with a positive MFE contribute; a zero MFE would turn
        # the whole average into inf through division by zero.
        # NOTE(review): this divides PnL by MFE - confirm both share units.
        captured = (pnl > 0) & (mfe > 0)
        efficiency = np.mean(pnl[captured] / mfe[captured]) if captured.any() else 0

        fig.add_annotation(
            x=0.02,
            y=0.98,
            xref="paper",
            yref="paper",
            text=f"<b>Edge Ratio:</b> {edge_ratio:.2f}<br><b>Exit Efficiency:</b> {efficiency:.1%}",
            showarrow=False,
            font={"size": 12},
            align="left",
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="gray",
            borderwidth=1,
        )

    fig.update_layout(
        legend={"yanchor": "top", "y": 0.99, "xanchor": "right", "x": 0.99},
    )

    return fig
324
+
325
+
326
+ # =============================================================================
327
+ # Exit Reason Analysis
328
+ # =============================================================================
329
+
330
+
331
def _summarize_exit_reasons(trades_df, exit_reason_col, pnl_col):
    """Aggregate trades per exit reason and win/loss outcome.

    Returns a list with one dict per exit reason, sorted by reason so the
    chart order is reproducible. Each dict has the keys ``label`` (display
    name, ``"Unknown"`` for a null reason), ``count``, ``pnl`` (net PnL) and
    ``outcomes`` (mapping ``"Winner"``/``"Loser"`` to ``(count, pnl)``).
    A trade is a winner only when its PnL is strictly positive.
    """
    import polars as pl

    rows = (
        trades_df.with_columns(
            pl.when(pl.col(pnl_col) > 0)
            .then(pl.lit("Winner"))
            .otherwise(pl.lit("Loser"))
            .alias("_outcome")
        )
        .group_by([exit_reason_col, "_outcome"])
        .agg(
            [
                pl.count().alias("_count"),
                pl.col(pnl_col).sum().alias("_total_pnl"),
            ]
        )
        # group_by does not promise a stable row order; sort explicitly.
        .sort([exit_reason_col, "_outcome"])
        .to_dicts()
    )

    summary = {}
    for row in rows:
        reason = row[exit_reason_col]
        label = "Unknown" if reason is None else str(reason)
        entry = summary.setdefault(
            label, {"label": label, "count": 0, "pnl": 0.0, "outcomes": {}}
        )
        count = row["_count"]
        pnl = row["_total_pnl"] or 0.0  # a null sum is treated as zero PnL
        entry["count"] += count
        entry["pnl"] += pnl
        prev_count, prev_pnl = entry["outcomes"].get(row["_outcome"], (0, 0.0))
        entry["outcomes"][row["_outcome"]] = (prev_count + count, prev_pnl + pnl)
    return list(summary.values())


def plot_exit_reason_breakdown(
    trades_df: pl.DataFrame,
    *,
    chart_type: Literal["sunburst", "treemap", "bar", "pie"] = "sunburst",
    show_pnl_contribution: bool = True,
    show_win_loss_split: bool = True,
    exit_reason_col: str = "exit_reason",
    pnl_col: str = "pnl",
    theme: str | None = None,
    height: int = 500,
    width: int | None = None,
) -> go.Figure:
    """Create exit reason breakdown visualization.

    Shows distribution of exit reasons and their PnL contribution.

    Parameters
    ----------
    trades_df : pl.DataFrame
        Trade data with exit_reason and pnl columns
    chart_type : str, default "sunburst"
        Type of chart: "sunburst", "treemap", "bar", or "pie". Any other
        value renders the bar chart.
    show_pnl_contribution : bool, default True
        Size sunburst/treemap/pie segments by absolute PnL rather than trade
        count. The bar chart always plots trade counts and shows total PnL as
        bar text.
    show_win_loss_split : bool, default True
        Split by winner/loser within each exit reason (sunburst and treemap
        only; pie and bar always show one segment per exit reason).
    exit_reason_col : str, default "exit_reason"
        Column name for exit reason
    pnl_col : str, default "pnl"
        Column name for PnL
    theme : str, optional
        Plot theme
    height : int, default 500
        Figure height
    width : int, optional
        Figure width

    Returns
    -------
    go.Figure
        Exit reason breakdown chart

    Notes
    -----
    Sunburst and treemap use ``branchvalues="total"``, which requires every
    parent value to be at least the sum of its children. Parent values are
    therefore built bottom-up as the sum of their children. With PnL sizing
    and the win/loss split, an exit reason's sector is its gross absolute PnL
    (``|winner PnL| + |loser PnL|``), not its net PnL.

    Examples
    --------
    >>> fig = plot_exit_reason_breakdown(trades_df, chart_type="sunburst")
    >>> fig.show()
    """
    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)

    reasons = _summarize_exit_reasons(trades_df, exit_reason_col, pnl_col)

    def magnitude(count, pnl):
        """Return the segment size for the selected value type."""
        return abs(pnl) if show_pnl_contribution else count

    value_type = "PnL Contribution" if show_pnl_contribution else "Trade Count"
    # Label the hover value for what it actually is, not always "Count".
    value_hover = "|PnL|: $%{value:,.2f}" if show_pnl_contribution else "Count: %{value}"

    # Create chart
    if chart_type in ("sunburst", "treemap"):
        # Build hierarchical data: root -> exit reason -> (optional) outcome
        labels = ["All Trades"]
        parents = [""]
        values = [0]
        colors = ["#6C757D"]

        for reason in reasons:
            if show_win_loss_split:
                children = [
                    (outcome, magnitude(*reason["outcomes"][outcome]))
                    for outcome in ("Winner", "Loser")
                    if outcome in reason["outcomes"]
                ]
                reason_value = sum(value for _, value in children)
                reason_color = "#3498DB" if reason["pnl"] > 0 else "#E74C3C"
            else:
                children = []
                reason_value = magnitude(reason["count"], reason["pnl"])
                reason_color = "#28A745" if reason["pnl"] > 0 else "#DC3545"

            labels.append(reason["label"])
            parents.append("All Trades")
            values.append(reason_value)
            colors.append(reason_color)
            values[0] += reason_value  # root is the sum of its children

            # Add win/loss under each reason
            for outcome, value in children:
                labels.append(f"{reason['label']} - {outcome}")
                parents.append(reason["label"])
                values.append(value)
                colors.append("#28A745" if outcome == "Winner" else "#DC3545")

        trace_type = go.Sunburst if chart_type == "sunburst" else go.Treemap
        fig = go.Figure(
            trace_type(
                labels=labels,
                parents=parents,
                values=values,
                marker={"colors": colors},
                branchvalues="total",
                hovertemplate="<b>%{label}</b><br>" + value_hover + "<extra></extra>",
            )
        )
    elif chart_type == "pie":
        fig = go.Figure(
            go.Pie(
                labels=[reason["label"] for reason in reasons],
                values=[magnitude(reason["count"], reason["pnl"]) for reason in reasons],
                marker={
                    "colors": [
                        "#28A745" if reason["pnl"] > 0 else "#DC3545" for reason in reasons
                    ]
                },
                hole=0.4,
                hovertemplate=(
                    "<b>%{label}</b><br>" + value_hover + "<br>%{percent}<extra></extra>"
                ),
            )
        )
    else:  # bar
        pnls = [reason["pnl"] for reason in reasons]

        fig = go.Figure()

        fig.add_trace(
            go.Bar(
                x=[reason["label"] for reason in reasons],
                y=[reason["count"] for reason in reasons],
                marker_color=["#28A745" if p > 0 else "#DC3545" for p in pnls],
                text=[f"${p:,.0f}" for p in pnls],
                textposition="outside",
                hovertemplate="<b>%{x}</b><br>Count: %{y}<br>Total PnL: %{text}<extra></extra>",
            )
        )

        fig.update_layout(
            xaxis_title="Exit Reason",
            yaxis_title="Number of Trades",
        )

    fig.update_layout(
        title=f"Exit Reason Breakdown ({value_type})",
        height=height,
        width=width,
        **{k: v for k, v in theme_config["layout"].items() if k != "margin"},
    )

    return fig
519
+
520
+
521
+ # =============================================================================
522
+ # Trade Waterfall
523
+ # =============================================================================
524
+
525
+
526
def plot_trade_waterfall(
    trades_df: pl.DataFrame,
    *,
    n_trades: int | None = None,
    sort_by: Literal["time", "pnl", "abs_pnl"] = "time",
    show_cumulative_line: bool = True,
    group_by_day: bool = False,
    initial_equity: float = 100000.0,
    pnl_col: str = "pnl",
    time_col: str = "exit_time",
    theme: str | None = None,
    height: int = 500,
    width: int | None = None,
) -> go.Figure:
    """Create trade-by-trade PnL waterfall chart.

    Shows each trade's contribution to cumulative PnL as a bar, optionally
    overlaid with a running equity line on a secondary y-axis.

    Parameters
    ----------
    trades_df : pl.DataFrame
        Trade data with pnl column
    n_trades : int, optional
        Limit to last N trades. The most recent N rows are selected first
        (chronologically by ``time_col`` when that column exists, otherwise
        by row order) and only then ordered according to ``sort_by``.
        None for all.
    sort_by : str, default "time"
        How to order trades: "time", "pnl", or "abs_pnl". "pnl" and
        "abs_pnl" order from largest to smallest.
    show_cumulative_line : bool, default True
        Overlay cumulative PnL line
    group_by_day : bool, default False
        Aggregate by day (useful for high-frequency strategies). Only applied
        when ``time_col`` is present; the result is ordered by date.
    initial_equity : float, default 100000.0
        Starting equity for cumulative calculation
    pnl_col : str, default "pnl"
        Column name for PnL
    time_col : str, default "exit_time"
        Column name for trade time
    theme : str, optional
        Plot theme
    height : int, default 500
        Figure height
    width : int, optional
        Figure width

    Returns
    -------
    go.Figure
        Waterfall chart of trade PnL

    Examples
    --------
    >>> fig = plot_trade_waterfall(trades_df, n_trades=50, show_cumulative_line=True)
    >>> fig.show()
    """
    import polars as pl

    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)

    has_time = time_col in trades_df.columns

    # Pick the window of trades BEFORE applying the display ordering.
    # Taking the tail after a PnL sort would keep the N lowest-ranked trades
    # rather than the N most recent ones promised by ``n_trades``.
    trades = trades_df
    if has_time and (sort_by == "time" or n_trades is not None):
        trades = trades.sort(time_col)
    if n_trades is not None:
        trades = trades.tail(n_trades)

    # Display ordering ("time" is already handled above; anything else keeps
    # the current row order).
    if sort_by == "pnl":
        trades = trades.sort(pnl_col, descending=True)
    elif sort_by == "abs_pnl":
        # Sorting on an expression avoids leaving a helper column behind.
        trades = trades.sort(pl.col(pnl_col).abs(), descending=True)

    # Group by day if requested
    if group_by_day and has_time:
        trades = (
            trades.with_columns(pl.col(time_col).dt.date().alias("date"))
            .group_by("date")
            .agg(
                [
                    pl.col(pnl_col).sum().alias(pnl_col),
                    pl.count().alias("n_trades"),
                ]
            )
            .sort("date")
        )
        x_labels = [str(d) for d in trades["date"].to_list()]
        hover_extra = "<br>Trades: %{customdata[1]}"
        custom_data = np.column_stack(
            [
                trades[pnl_col].to_numpy(),
                trades["n_trades"].to_numpy(),
            ]
        )
    else:
        x_labels = [f"Trade {i + 1}" for i in range(len(trades))]
        hover_extra = ""
        custom_data = trades[pnl_col].to_numpy().reshape(-1, 1)

    pnl_values = trades[pnl_col].to_numpy()
    cumulative = np.cumsum(pnl_values)

    # Create figure with secondary y-axis (bars on primary, equity on secondary)
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # Waterfall bars: green for strictly positive PnL, red otherwise
    colors = ["#28A745" if p > 0 else "#DC3545" for p in pnl_values]

    fig.add_trace(
        go.Bar(
            x=x_labels,
            y=pnl_values,
            marker_color=colors,
            name="Trade PnL",
            hovertemplate=f"<b>%{{x}}</b><br>PnL: $%{{y:,.2f}}{hover_extra}<extra></extra>",
            customdata=custom_data,
        ),
        secondary_y=False,
    )

    # Cumulative line
    if show_cumulative_line:
        fig.add_trace(
            go.Scatter(
                x=x_labels,
                y=initial_equity + cumulative,
                mode="lines+markers",
                name="Cumulative Equity",
                line={"color": "#2E86AB", "width": 2},
                marker={"size": 4},
                hovertemplate="<b>%{x}</b><br>Equity: $%{y:,.2f}<extra></extra>",
            ),
            secondary_y=True,
        )

    # Add zero line
    fig.add_hline(y=0, line_dash="solid", line_color="gray", line_width=1, secondary_y=False)

    # Update layout (the theme's own margin is deliberately not applied)
    fig.update_layout(
        title="Trade PnL Waterfall",
        height=height,
        width=width,
        legend={"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01},
        **{k: v for k, v in theme_config["layout"].items() if k != "margin"},
    )

    fig.update_yaxes(title_text="Trade PnL ($)", secondary_y=False)
    fig.update_yaxes(title_text="Cumulative Equity ($)", secondary_y=True)
    fig.update_xaxes(title_text="Trade" if not group_by_day else "Date")

    # Rotate x labels if many trades
    if len(x_labels) > 20:
        fig.update_xaxes(tickangle=45)

    return fig
682
+
683
+
684
+ # =============================================================================
685
+ # Duration Distribution
686
+ # =============================================================================
687
+
688
+
689
def plot_trade_duration_distribution(
    trades_df: pl.DataFrame,
    *,
    duration_col: str = "bars_held",
    split_by: Literal["outcome", "exit_reason", "direction", "none"] = "outcome",
    pnl_col: str = "pnl",
    bin_count: int = 30,
    show_statistics: bool = True,
    theme: str | None = None,
    height: int = 450,
    width: int | None = None,
) -> go.Figure:
    """Plot distribution of trade holding periods.

    Parameters
    ----------
    trades_df : pl.DataFrame
        Trade data with duration column
    duration_col : str, default "bars_held"
        Column name for holding period
    split_by : str, default "outcome"
        How to split distribution: "outcome", "exit_reason", "direction", or "none".
        If the column a split needs is missing, a single unsplit histogram is
        drawn instead.
    pnl_col : str, default "pnl"
        Column name for PnL (used for outcome split)
    bin_count : int, default 30
        Number of histogram bins
    show_statistics : bool, default True
        Show mean/median annotations. They are computed over finite
        durations only and omitted when there are none.
    theme : str, optional
        Plot theme
    height : int, default 450
        Figure height
    width : int, optional
        Figure width

    Returns
    -------
    go.Figure
        Duration distribution histogram

    Examples
    --------
    >>> fig = plot_trade_duration_distribution(trades_df, split_by="outcome")
    >>> fig.show()
    """

    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)

    fig = create_base_figure(
        title="Trade Duration Distribution",
        xaxis_title="Holding Period (bars)",
        yaxis_title="Number of Trades",
        height=height,
        width=width,
        theme=theme,
    )

    durations = trades_df[duration_col].to_numpy()

    def _add_histogram(values, name, color):
        # Every series shares opacity and bin count; only data/label/colour vary.
        fig.add_trace(
            go.Histogram(
                x=values,
                name=name,
                marker_color=color,
                opacity=0.7,
                nbinsx=bin_count,
            )
        )

    if split_by == "outcome" and pnl_col in trades_df.columns:
        # Winners are strictly positive PnL; break-even trades count as losers.
        is_winner = trades_df[pnl_col].to_numpy() > 0
        _add_histogram(durations[is_winner], "Winners", "#28A745")
        _add_histogram(durations[~is_winner], "Losers", "#DC3545")
        fig.update_layout(barmode="overlay")

    elif split_by == "exit_reason" and "exit_reason" in trades_df.columns:
        reason_values = trades_df["exit_reason"].to_numpy()
        # maintain_order keeps first-appearance order so that trace order and
        # colour assignment are reproducible between calls.
        exit_reasons = trades_df["exit_reason"].unique(maintain_order=True).to_list()
        colors = get_color_scheme("set2")

        for i, reason in enumerate(exit_reasons):
            _add_histogram(durations[reason_values == reason], reason, colors[i % len(colors)])
        fig.update_layout(barmode="stack")

    elif split_by == "direction" and "direction" in trades_df.columns:
        direction_values = trades_df["direction"].to_numpy()
        # Only the literal values "long" and "short" are plotted.
        for direction, color in (("long", "#28A745"), ("short", "#DC3545")):
            _add_histogram(durations[direction_values == direction], direction.title(), color)
        fig.update_layout(barmode="overlay")

    else:
        _add_histogram(durations, "All Trades", theme_config["colorway"][0])

    # Add statistics
    if show_statistics:
        # Null durations arrive as NaN; a single one would otherwise turn both
        # statistics into NaN, and empty input would do the same with a warning.
        finite_durations = durations[np.isfinite(durations)]
        if finite_durations.size:
            mean_dur = float(np.mean(finite_durations))
            median_dur = float(np.median(finite_durations))

            fig.add_vline(
                x=mean_dur,
                line_dash="dash",
                line_color="#2E86AB",
                annotation_text=f"Mean: {mean_dur:.1f}",
                annotation_position="top",
            )
            fig.add_vline(
                x=median_dur,
                line_dash="dot",
                line_color="#E74C3C",
                annotation_text=f"Median: {median_dur:.1f}",
                annotation_position="bottom",
            )

    fig.update_layout(
        legend={"yanchor": "top", "y": 0.99, "xanchor": "right", "x": 0.99},
    )

    return fig
841
+
842
+
843
+ # =============================================================================
844
+ # Size vs Return Analysis
845
+ # =============================================================================
846
+
847
+
848
def plot_trade_size_vs_return(
    trades_df: pl.DataFrame,
    *,
    size_metric: Literal["quantity", "notional", "risk_amount"] = "notional",
    return_metric: Literal["pnl", "pnl_pct"] = "pnl_pct",
    show_regression: bool = True,
    show_correlation: bool = True,
    color_by: Literal["outcome", "exit_reason", "none"] = "outcome",
    theme: str | None = None,
    height: int = 500,
    width: int | None = None,
) -> go.Figure:
    """Analyze relationship between position size and returns.

    Useful for detecting if larger positions perform differently.

    Parameters
    ----------
    trades_df : pl.DataFrame
        Trade data. A "pnl" column is required unless ``return_metric`` is
        "pnl_pct" and a "pnl_pct" column exists.
    size_metric : str, default "notional"
        Size measure: "quantity", "notional", or "risk_amount". "notional"
        and "risk_amount" are both computed as ``|entry_price * quantity|``.
        When the needed columns are missing, ``|quantity|`` is used if
        available, otherwise a constant size of 1.
    return_metric : str, default "pnl_pct"
        Return measure: "pnl" or "pnl_pct". Falls back to "pnl" when the
        "pnl_pct" column is missing.
    show_regression : bool, default True
        Show regression line. Skipped when fewer than three finite points
        exist or when all sizes are identical (no line can be fitted).
    show_correlation : bool, default True
        Show correlation annotation (needs at least three finite points)
    color_by : str, default "outcome"
        Color points by: "outcome", "exit_reason", or "none"
    theme : str, optional
        Plot theme
    height : int, default 500
        Figure height
    width : int, optional
        Figure width

    Returns
    -------
    go.Figure
        Size vs return scatter plot
    """

    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)

    columns = trades_df.columns
    has_quantity = "quantity" in columns
    has_notional_inputs = has_quantity and "entry_price" in columns

    # Calculate size metric
    if size_metric == "quantity" and has_quantity:
        sizes = np.abs(trades_df["quantity"].to_numpy())
        x_label = "Position Size (units)"
    elif size_metric in ("notional", "risk_amount") and has_notional_inputs:
        # Both metrics need quantity as well as entry_price; checking both
        # prevents a missing-column error on the risk_amount path.
        sizes = np.abs(trades_df["entry_price"].to_numpy() * trades_df["quantity"].to_numpy())
        x_label = "Notional Value ($)" if size_metric == "notional" else "Risk Amount ($)"
    else:
        sizes = np.abs(trades_df["quantity"].to_numpy()) if has_quantity else np.ones(len(trades_df))
        x_label = "Position Size"

    # Get returns
    if return_metric == "pnl_pct" and "pnl_pct" in columns:
        returns = trades_df["pnl_pct"].to_numpy()
        y_label = "Return (%)"
    else:
        returns = trades_df["pnl"].to_numpy()
        y_label = "PnL ($)"

    fig = create_base_figure(
        title="Position Size vs Return",
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        theme=theme,
    )

    def _add_points(mask, name, color):
        # mask=None plots every trade; otherwise only the selected rows.
        xs = sizes if mask is None else sizes[mask]
        ys = returns if mask is None else returns[mask]
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                mode="markers",
                name=name,
                marker={"color": color, "size": 8, "opacity": 0.6},
            )
        )

    # Color by outcome or exit reason
    if color_by == "outcome" and "pnl" in columns:
        # Winners are strictly positive PnL; everything else is a loser.
        winners = trades_df["pnl"].to_numpy() > 0
        _add_points(winners, "Winners", "#28A745")
        _add_points(~winners, "Losers", "#DC3545")
    elif color_by == "exit_reason" and "exit_reason" in columns:
        reason_values = trades_df["exit_reason"].to_numpy()
        # maintain_order gives a reproducible trace order and colour mapping.
        exit_reasons = trades_df["exit_reason"].unique(maintain_order=True).to_list()
        colors = get_color_scheme("set2")

        for i, reason in enumerate(exit_reasons):
            _add_points(reason_values == reason, reason, colors[i % len(colors)])
    else:
        _add_points(None, "Trades", theme_config["colorway"][0])

    # Only finite pairs can take part in the regression / correlation.
    valid = np.isfinite(sizes) & np.isfinite(returns)
    enough_points = valid.sum() > 2

    # Add regression line
    if show_regression and enough_points:
        from scipy import stats

        x_valid = sizes[valid]
        x_min, x_max = x_valid.min(), x_valid.max()

        # linregress raises ValueError when every x is identical (e.g. the
        # constant-size fallback above), so only fit when sizes vary.
        if x_min != x_max:
            fit = stats.linregress(x_valid, returns[valid])

            x_line = np.array([x_min, x_max])
            y_line = fit.slope * x_line + fit.intercept

            fig.add_trace(
                go.Scatter(
                    x=x_line,
                    y=y_line,
                    mode="lines",
                    name=f"Regression (R²={fit.rvalue**2:.3f})",
                    line={"color": "gray", "dash": "dash", "width": 2},
                )
            )

    # Add correlation annotation
    if show_correlation and enough_points:
        corr = np.corrcoef(sizes[valid], returns[valid])[0, 1]

        fig.add_annotation(
            x=0.02,
            y=0.98,
            xref="paper",
            yref="paper",
            text=f"<b>Correlation:</b> {corr:.3f}",
            showarrow=False,
            font={"size": 12},
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="gray",
            borderwidth=1,
        )

    # Add zero line
    fig.add_hline(y=0, line_dash="solid", line_color="gray", line_width=1)

    fig.update_layout(
        legend={"yanchor": "top", "y": 0.99, "xanchor": "right", "x": 0.99},
    )

    return fig
1032
+
1033
+
1034
+ # =============================================================================
1035
+ # Consecutive Wins/Losses Analysis
1036
+ # =============================================================================
1037
+
1038
+
1039
def plot_consecutive_analysis(
    trades_df: pl.DataFrame,
    *,
    metric: Literal["wins", "losses", "pnl"] = "wins",
    pnl_col: str = "pnl",
    theme: str | None = None,
    height: int = 450,
    width: int | None = None,
) -> go.Figure:
    """Visualize the lengths of consecutive winning and losing runs.

    Trades are read in row order. A trade is a win when its PnL is strictly
    positive and a loss otherwise. Each maximal run of same-outcome trades is
    one streak; win-streak and loss-streak lengths are drawn as two
    side-by-side histograms, each captioned with its maximum and average.

    Parameters
    ----------
    trades_df : pl.DataFrame
        Trade data with pnl column
    metric : str, default "wins"
        Accepted values are "wins", "losses" and "pnl". The argument is not
        read by the current implementation; both panels are always drawn.
    pnl_col : str, default "pnl"
        Column name for PnL
    theme : str, optional
        Plot theme
    height : int, default 450
        Figure height
    width : int, optional
        Figure width

    Returns
    -------
    go.Figure
        Streak analysis visualization
    """
    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)

    outcome = trades_df[pnl_col].to_numpy() > 0

    # Run-length encode the win/loss sequence: a new run starts at index 0 and
    # wherever the outcome differs from the previous trade.
    if outcome.size:
        flips = np.flatnonzero(outcome[1:] != outcome[:-1]) + 1
        run_starts = np.concatenate(([0], flips))
        run_lengths = np.diff(np.append(run_starts, outcome.size))
        run_is_win = outcome[run_starts]
        win_streaks = run_lengths[run_is_win].tolist()
        loss_streaks = run_lengths[~run_is_win].tolist()
    else:
        win_streaks = []
        loss_streaks = []

    # (streak lengths, trace name, colour, subplot column, caption x-position)
    panels = (
        (win_streaks, "Win Streaks", "#28A745", 1, 0.25),
        (loss_streaks, "Loss Streaks", "#DC3545", 2, 0.75),
    )

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Win Streak Distribution", "Loss Streak Distribution"),
    )

    # One histogram per panel, with one bin per possible streak length;
    # a panel with no streaks gets no trace.
    for lengths, trace_name, color, column, _ in panels:
        if not lengths:
            continue
        fig.add_trace(
            go.Histogram(
                x=lengths,
                name=trace_name,
                marker_color=color,
                opacity=0.7,
                nbinsx=max(lengths),
            ),
            row=1,
            col=column,
        )

    # Summary caption above each panel; zeros are shown when a panel is empty.
    for lengths, _, color, _, x_pos in panels:
        if lengths:
            longest = max(lengths)
            average = np.mean(lengths)
        else:
            longest = 0
            average = 0
        fig.add_annotation(
            x=x_pos,
            y=1.15,
            xref="paper",
            yref="paper",
            text=f"Max: {longest} | Avg: {average:.1f}",
            showarrow=False,
            font={"size": 11, "color": color},
        )

    layout_overrides = {
        key: value for key, value in theme_config["layout"].items() if key != "margin"
    }
    fig.update_layout(
        title="Consecutive Trade Streak Analysis",
        height=height,
        width=width,
        showlegend=False,
        **layout_overrides,
    )

    for column in (1, 2):
        fig.update_xaxes(title_text="Streak Length", row=1, col=column)
    for column in (1, 2):
        fig.update_yaxes(title_text="Frequency", row=1, col=column)

    return fig