ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,762 @@
1
+ """Cost attribution visualizations for backtest analysis.
2
+
3
+ Provides interactive Plotly visualizations for understanding
4
+ the impact of transaction costs on strategy performance.
5
+
6
+ Key visualizations:
7
+ - Cost waterfall (Gross → Commission → Slippage → Net)
8
+ - Cost sensitivity analysis (Sharpe degradation as costs increase)
9
+ - Cost over time (rolling cost impact)
10
+ - Cost by asset (identify high-cost positions)
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import TYPE_CHECKING, Literal
16
+
17
+ import numpy as np
18
+ import plotly.graph_objects as go
19
+ from plotly.subplots import make_subplots
20
+
21
+ from ml4t.diagnostic.visualization.core import get_theme_config
22
+
23
+ if TYPE_CHECKING:
24
+ import polars as pl
25
+
26
+
27
def plot_cost_waterfall(
    gross_pnl: float,
    commission: float,
    slippage: float,
    net_pnl: float | None = None,
    other_costs: dict[str, float] | None = None,
    title: str = "Cost Attribution Waterfall",
    show_percentages: bool = True,
    theme: str | None = None,
    height: int = 500,
    width: int | None = None,
) -> go.Figure:
    """Create a waterfall chart showing gross-to-net PnL decomposition.

    Visualizes how transaction costs (commission, slippage) erode
    gross trading profits into net returns.

    Parameters
    ----------
    gross_pnl : float
        Gross profit/loss before costs
    commission : float
        Total commission costs. The sign is ignored: the magnitude is always
        shown (and subtracted) as a cost.
    slippage : float
        Total slippage costs. The sign is ignored, as for ``commission``.
    net_pnl : float, optional
        Net PnL after all costs. If not provided, computed as ``gross_pnl``
        minus the magnitudes of all listed costs. If provided and it does not
        match that figure, an extra "Unattributed" bar shows the difference so
        the bars reconcile with the reported net.
    other_costs : dict[str, float], optional
        Additional cost categories (e.g., {"Financing": 500, "Fees": 200});
        signs are ignored as for ``commission``.
    title : str
        Chart title
    show_percentages : bool
        Whether to label each intermediate bar with its size as a percentage
        of ``abs(gross_pnl)`` (skipped when ``gross_pnl`` is zero)
    theme : str, optional
        Theme name (default, dark, print, presentation)
    height : int
        Figure height in pixels
    width : int, optional
        Figure width in pixels

    Returns
    -------
    go.Figure
        Plotly figure with waterfall chart

    Examples
    --------
    >>> fig = plot_cost_waterfall(
    ...     gross_pnl=100000,
    ...     commission=2500,
    ...     slippage=1500,
    ... )
    >>> fig.show()
    """
    import math

    theme_config = get_theme_config(theme)

    # Every cost is drawn as a reduction, whatever sign the caller used.
    cost_items: list[tuple[str, float]] = [
        ("Commission", -abs(commission)),
        ("Slippage", -abs(slippage)),
    ]
    if other_costs:
        cost_items.extend((name, -abs(cost)) for name, cost in other_costs.items())

    labels = ["Gross PnL"] + [name for name, _ in cost_items]
    values = [gross_pnl] + [signed for _, signed in cost_items]
    measures = ["absolute"] + ["relative"] * len(cost_items)

    # Net PnL implied by the bars above (cost values are already negative).
    # Deriving it from the same signed values keeps the computed net consistent
    # with the drawn bars even when costs were passed as negative numbers.
    implied_net = gross_pnl + sum(signed for _, signed in cost_items)
    if net_pnl is None:
        net_pnl = implied_net
    elif not math.isclose(net_pnl, implied_net, rel_tol=1e-9, abs_tol=1e-6):
        # Plotly sizes a "total" bar from the running sum of the preceding bars,
        # not from its own y value. Without this reconciling bar the Net PnL
        # label would show ``net_pnl`` while the bar is drawn at ``implied_net``.
        labels.append("Unattributed")
        values.append(net_pnl - implied_net)
        measures.append("relative")

    labels.append("Net PnL")
    values.append(net_pnl)
    measures.append("total")

    # Bar labels; intermediate bars optionally carry their share of gross PnL.
    if show_percentages and gross_pnl != 0:
        text = [f"${gross_pnl:,.0f}"]
        for val in values[1:-1]:
            pct = abs(val) / abs(gross_pnl) * 100
            text.append(f"${val:,.0f} ({pct:.1f}%)")
        text.append(f"${net_pnl:,.0f}")
    else:
        text = [f"${v:,.0f}" for v in values]

    # Determine colors (fall back to Plotly defaults for short colorways)
    colors = theme_config["colorway"]
    increasing_color = colors[0]  # Usually green/blue
    decreasing_color = colors[1] if len(colors) > 1 else "#EF553B"  # Red for costs
    totals_color = colors[2] if len(colors) > 2 else "#636EFA"  # Blue for totals

    fig = go.Figure(
        go.Waterfall(
            name="Cost Attribution",
            orientation="v",
            x=labels,
            y=values,
            measure=measures,
            text=text,
            textposition="outside",
            increasing={"marker": {"color": increasing_color}},
            decreasing={"marker": {"color": decreasing_color}},
            totals={"marker": {"color": totals_color}},
            connector={"line": {"color": "rgba(128, 128, 128, 0.5)", "width": 2}},
        )
    )

    # Build layout
    layout_updates = {
        "title": {"text": title, "font": {"size": 18}},
        "height": height,
        "yaxis": {"title": "PnL ($)", "tickformat": "$,.0f"},
        "showlegend": False,
    }
    if width:
        layout_updates["width"] = width

    # Merge theme layout without overwriting explicit settings
    for key, value in theme_config["layout"].items():
        if key not in layout_updates:
            layout_updates[key] = value

    fig.update_layout(**layout_updates)

    return fig
166
+
167
+
168
def plot_cost_sensitivity(
    returns: pl.Series | np.ndarray,
    base_costs_bps: float = 10.0,
    cost_multipliers: list[float] | None = None,
    trades_per_year: int = 252,
    risk_free_rate: float = 0.0,
    title: str = "Cost Sensitivity Analysis",
    show_breakeven: bool = True,
    theme: str | None = None,
    height: int = 500,
    width: int | None = None,
) -> go.Figure:
    """Analyze how Sharpe ratio degrades as transaction costs increase.

    Shows the sensitivity of risk-adjusted returns to transaction costs,
    helping identify the breakeven point where strategy becomes unprofitable.

    Parameters
    ----------
    returns : pl.Series or np.ndarray
        Gross daily returns (before costs). Missing / non-finite values are
        dropped before computing statistics.
    base_costs_bps : float
        Base transaction cost in basis points (e.g., 10 = 0.1%)
    cost_multipliers : list[float], optional
        Multipliers to test (default: [0, 0.5, 1, 1.5, 2, 3, 5]). They are
        evaluated in ascending order regardless of the order given.
    trades_per_year : int
        Estimated number of trades per year for cost impact
    risk_free_rate : float
        Annual risk-free rate for Sharpe calculation
    title : str
        Chart title
    show_breakeven : bool
        Whether to annotate the breakeven cost level (only drawn when the
        Sharpe ratio crosses zero within the tested cost range)
    theme : str, optional
        Theme name
    height : int
        Figure height in pixels
    width : int, optional
        Figure width in pixels

    Returns
    -------
    go.Figure
        Plotly figure with cost sensitivity chart

    Raises
    ------
    ValueError
        If fewer than two finite return observations remain.
    """
    import math

    import polars as pl

    theme_config = get_theme_config(theme)

    # Convert to a float array and drop missing observations (polars nulls
    # surface as NaN); a single NaN would otherwise poison every statistic.
    if isinstance(returns, pl.Series):
        returns_arr = returns.to_numpy()
    else:
        returns_arr = np.asarray(returns)
    returns_arr = np.asarray(returns_arr, dtype=float)
    returns_arr = returns_arr[np.isfinite(returns_arr)]
    if returns_arr.size < 2:
        raise ValueError(
            "plot_cost_sensitivity needs at least 2 finite return observations, "
            f"got {returns_arr.size}"
        )

    # Default multipliers
    if cost_multipliers is None:
        cost_multipliers = [0, 0.5, 1, 1.5, 2, 3, 5]
    # Evaluate cost levels in ascending order: the breakeven search and the
    # line plots both assume cost increases along the lists.
    multipliers = sorted(cost_multipliers)

    gross_mean = np.mean(returns_arr)
    gross_std = np.std(returns_arr, ddof=1)
    daily_rf = risk_free_rate / 252

    cost_levels: list[float] = []
    sharpe_values: list[float] = []
    cagr_values: list[float] = []

    for mult in multipliers:
        cost_bps = base_costs_bps * mult
        # Cost per trade as a decimal fraction
        cost_per_trade = cost_bps / 10000
        # Estimate daily cost drag (assuming uniform trading)
        daily_cost_drag = cost_per_trade * (trades_per_year / 252)
        net_mean = gross_mean - daily_cost_drag

        # A constant drag shifts the mean but not the dispersion, so the gross
        # standard deviation is reused for every cost level.
        if gross_std > 0:
            sharpe = (net_mean - daily_rf) / gross_std * np.sqrt(252)
        else:
            sharpe = 0
        # Approximate CAGR from the mean daily net return
        cagr = (1 + net_mean) ** 252 - 1

        cost_levels.append(cost_bps)
        sharpe_values.append(sharpe)
        cagr_values.append(cagr * 100)  # As percentage

    colors = theme_config["colorway"]

    # Create subplot with Sharpe and CAGR
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Sharpe Ratio vs Costs", "CAGR vs Costs"),
        horizontal_spacing=0.12,
    )

    # Sharpe trace
    fig.add_trace(
        go.Scatter(
            x=cost_levels,
            y=sharpe_values,
            mode="lines+markers",
            name="Sharpe Ratio",
            line={"color": colors[0], "width": 3},
            marker={"size": 10},
            hovertemplate="Cost: %{x:.1f} bps<br>Sharpe: %{y:.2f}<extra></extra>",
        ),
        row=1,
        col=1,
    )

    # Add zero line for Sharpe
    fig.add_hline(
        y=0,
        line_dash="dash",
        line_color="gray",
        row=1,
        col=1,
    )

    # CAGR trace
    fig.add_trace(
        go.Scatter(
            x=cost_levels,
            y=cagr_values,
            mode="lines+markers",
            name="CAGR (%)",
            line={"color": colors[1] if len(colors) > 1 else colors[0], "width": 3},
            marker={"size": 10},
            hovertemplate="Cost: %{x:.1f} bps<br>CAGR: %{y:.1f}%<extra></extra>",
        ),
        row=1,
        col=2,
    )

    # Add zero line for CAGR
    fig.add_hline(
        y=0,
        line_dash="dash",
        line_color="gray",
        row=1,
        col=2,
    )

    # Find breakeven point (where Sharpe crosses zero)
    if show_breakeven:
        for i in range(len(sharpe_values) - 1):
            s0, s1 = sharpe_values[i], sharpe_values[i + 1]
            if s0 > 0 and s1 <= 0:
                x0, x1 = cost_levels[i], cost_levels[i + 1]
                # Sharpe is linear in cost above, so linear interpolation
                # between bracketing points is exact.
                breakeven = x0 + (0 - s0) / (s1 - s0) * (x1 - x0)
                fig.add_vline(
                    x=breakeven,
                    line_dash="dot",
                    line_color="red",
                    annotation_text=f"Breakeven: {breakeven:.1f} bps",
                    annotation_position="top",
                    row=1,
                    col=1,
                )
                break

    # Mark current cost level (tolerant float match instead of exact `in`)
    current_idx = next(
        (i for i, level in enumerate(cost_levels) if math.isclose(level, base_costs_bps)),
        None,
    )
    if current_idx is not None:
        fig.add_annotation(
            x=base_costs_bps,
            y=sharpe_values[current_idx],
            text="Current",
            showarrow=True,
            arrowhead=2,
            row=1,
            col=1,
        )

    # Build layout
    layout_updates = {
        "title": {"text": title, "font": {"size": 18}},
        "height": height,
        "showlegend": False,
        "xaxis": {"title": "Transaction Cost (bps)"},
        "xaxis2": {"title": "Transaction Cost (bps)"},
        "yaxis": {"title": "Sharpe Ratio"},
        "yaxis2": {"title": "CAGR (%)"},
    }
    if width:
        layout_updates["width"] = width

    # Merge theme layout without overwriting explicit settings
    for key, value in theme_config["layout"].items():
        if key not in layout_updates:
            layout_updates[key] = value

    fig.update_layout(**layout_updates)

    return fig
370
+
371
+
372
def plot_cost_over_time(
    dates: pl.Series | np.ndarray,
    gross_returns: pl.Series | np.ndarray,
    net_returns: pl.Series | np.ndarray,
    rolling_window: int = 63,
    title: str = "Cost Impact Over Time",
    theme: str | None = None,
    height: int = 450,
    width: int | None = None,
) -> go.Figure:
    """Visualize how transaction costs impact returns over time.

    Shows the difference between gross and net returns on a rolling basis,
    helping identify periods of high cost impact.

    Parameters
    ----------
    dates : pl.Series or np.ndarray
        Date index
    gross_returns : pl.Series or np.ndarray
        Gross daily returns (before costs)
    net_returns : pl.Series or np.ndarray
        Net daily returns (after costs)
    rolling_window : int
        Rolling window for smoothing (default: 63 = ~3 months). The first
        ``rolling_window - 1`` points of each curve are NaN (not drawn).
    title : str
        Chart title
    theme : str, optional
        Theme name
    height : int
        Figure height in pixels
    width : int, optional
        Figure width in pixels

    Returns
    -------
    go.Figure
        Plotly figure with rolling cost impact

    Raises
    ------
    ValueError
        If ``rolling_window`` is less than 1, or if ``dates``,
        ``gross_returns`` and ``net_returns`` differ in length.
    """
    import polars as pl

    if rolling_window < 1:
        raise ValueError(f"rolling_window must be >= 1, got {rolling_window}")

    theme_config = get_theme_config(theme)
    colors = theme_config["colorway"]

    # Convert inputs: dates to a plain list for Plotly, returns to float arrays.
    dates_arr = dates.to_list() if isinstance(dates, pl.Series) else list(dates)
    gross_arr = np.asarray(
        gross_returns.to_numpy() if isinstance(gross_returns, pl.Series) else gross_returns,
        dtype=float,
    )
    net_arr = np.asarray(
        net_returns.to_numpy() if isinstance(net_returns, pl.Series) else net_returns,
        dtype=float,
    )

    # Without this check, mismatched inputs either broadcast silently (length-1
    # arrays) or plot values against the wrong dates.
    if not (len(dates_arr) == len(gross_arr) == len(net_arr)):
        raise ValueError(
            "dates, gross_returns and net_returns must have the same length, got "
            f"{len(dates_arr)}, {len(gross_arr)} and {len(net_arr)}"
        )

    # Calculate cost drag
    cost_drag = gross_arr - net_arr

    def rolling_mean(arr: np.ndarray, window: int) -> np.ndarray:
        """Trailing mean over ``window`` points; NaN where the window is incomplete.

        A NaN input only affects the windows that contain it.
        """
        result = np.full(len(arr), np.nan)
        if window > len(arr):
            return result
        windows = np.lib.stride_tricks.sliding_window_view(arr, window)
        result[window - 1 :] = windows.mean(axis=1)
        return result

    # Annualize (x252) and express in percent (x100)
    rolling_gross = rolling_mean(gross_arr, rolling_window) * 252 * 100
    rolling_net = rolling_mean(net_arr, rolling_window) * 252 * 100
    rolling_cost = rolling_mean(cost_drag, rolling_window) * 252 * 100

    fig = go.Figure()

    # Gross returns
    fig.add_trace(
        go.Scatter(
            x=dates_arr,
            y=rolling_gross,
            name="Gross Returns (ann.)",
            mode="lines",
            line={"color": colors[0], "width": 2},
            hovertemplate="%{x}<br>Gross: %{y:.1f}%<extra></extra>",
        )
    )

    # Net returns
    fig.add_trace(
        go.Scatter(
            x=dates_arr,
            y=rolling_net,
            name="Net Returns (ann.)",
            mode="lines",
            line={"color": colors[1] if len(colors) > 1 else colors[0], "width": 2},
            hovertemplate="%{x}<br>Net: %{y:.1f}%<extra></extra>",
        )
    )

    # Cost drag (as filled area)
    fig.add_trace(
        go.Scatter(
            x=dates_arr,
            y=rolling_cost,
            name="Cost Drag (ann.)",
            mode="lines",
            fill="tozeroy",
            line={"color": "rgba(239, 85, 59, 0.7)", "width": 1},
            fillcolor="rgba(239, 85, 59, 0.3)",
            hovertemplate="%{x}<br>Cost Drag: %{y:.1f}%<extra></extra>",
        )
    )

    # Build layout
    layout_updates = {
        "title": {"text": title, "font": {"size": 18}},
        "height": height,
        "xaxis": {"title": "Date"},
        "yaxis": {"title": "Annualized Return (%)"},
        "legend": {"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01},
        "hovermode": "x unified",
    }
    if width:
        layout_updates["width"] = width

    # Merge theme layout without overwriting explicit settings
    for key, value in theme_config["layout"].items():
        if key not in layout_updates:
            layout_updates[key] = value

    fig.update_layout(**layout_updates)

    return fig
506
+
507
+
508
+ def plot_cost_by_asset(
509
+ trades: pl.DataFrame,
510
+ top_n: int = 10,
511
+ cost_column: str = "cost",
512
+ symbol_column: str = "symbol",
513
+ sort_by: Literal["total", "per_trade", "percentage"] = "total",
514
+ title: str = "Transaction Costs by Asset",
515
+ theme: str | None = None,
516
+ height: int = 450,
517
+ width: int | None = None,
518
+ ) -> go.Figure:
519
+ """Show transaction cost breakdown by asset.
520
+
521
+ Helps identify which assets incur the highest costs and may need
522
+ different position sizing or execution strategies.
523
+
524
+ Parameters
525
+ ----------
526
+ trades : pl.DataFrame
527
+ Trade records with symbol and cost columns
528
+ top_n : int
529
+ Number of top assets to show
530
+ cost_column : str
531
+ Name of the cost column
532
+ symbol_column : str
533
+ Name of the symbol column
534
+ sort_by : {"total", "per_trade", "percentage"}
535
+ How to rank assets:
536
+ - "total": Total cost in dollars
537
+ - "per_trade": Average cost per trade
538
+ - "percentage": Cost as % of gross PnL
539
+ title : str
540
+ Chart title
541
+ theme : str, optional
542
+ Theme name
543
+ height : int
544
+ Figure height in pixels
545
+ width : int, optional
546
+ Figure width in pixels
547
+
548
+ Returns
549
+ -------
550
+ go.Figure
551
+ Plotly figure with cost breakdown by asset
552
+ """
553
+ import polars as pl
554
+
555
+ theme_config = get_theme_config(theme)
556
+ colors = theme_config["colorway"]
557
+
558
+ # Check if required columns exist
559
+ if cost_column not in trades.columns:
560
+ # Try to calculate cost from pnl columns
561
+ if "gross_pnl" in trades.columns and "net_pnl" in trades.columns:
562
+ trades = trades.with_columns((pl.col("gross_pnl") - pl.col("net_pnl")).alias("cost"))
563
+ cost_column = "cost"
564
+ else:
565
+ raise ValueError(f"Cost column '{cost_column}' not found and cannot be calculated")
566
+
567
+ if symbol_column not in trades.columns:
568
+ raise ValueError(f"Symbol column '{symbol_column}' not found")
569
+
570
+ # Aggregate by symbol
571
+ agg_cols = [
572
+ pl.col(cost_column).sum().alias("total_cost"),
573
+ pl.col(cost_column).mean().alias("avg_cost"),
574
+ pl.col(cost_column).count().alias("n_trades"),
575
+ ]
576
+
577
+ if "gross_pnl" in trades.columns:
578
+ agg_cols.append(pl.col("gross_pnl").sum().alias("total_gross"))
579
+
580
+ cost_by_symbol = trades.group_by(symbol_column).agg(agg_cols)
581
+
582
+ # Calculate percentage if we have gross PnL
583
+ if "total_gross" in cost_by_symbol.columns:
584
+ cost_by_symbol = cost_by_symbol.with_columns(
585
+ (pl.col("total_cost") / pl.col("total_gross").abs() * 100).alias("cost_pct")
586
+ )
587
+
588
+ # Sort based on criteria
589
+ if sort_by == "total":
590
+ cost_by_symbol = cost_by_symbol.sort("total_cost", descending=True)
591
+ elif sort_by == "per_trade":
592
+ cost_by_symbol = cost_by_symbol.sort("avg_cost", descending=True)
593
+ elif sort_by == "percentage" and "cost_pct" in cost_by_symbol.columns:
594
+ cost_by_symbol = cost_by_symbol.sort("cost_pct", descending=True)
595
+
596
+ # Take top N
597
+ top_assets = cost_by_symbol.head(top_n)
598
+
599
+ symbols = top_assets[symbol_column].to_list()
600
+ total_costs = top_assets["total_cost"].to_list()
601
+ n_trades = top_assets["n_trades"].to_list()
602
+
603
+ # Determine what to show on secondary axis
604
+ show_pct = "cost_pct" in top_assets.columns and sort_by == "percentage"
605
+
606
+ if show_pct:
607
+ secondary_values = top_assets["cost_pct"].to_list()
608
+ secondary_name = "Cost %"
609
+ secondary_format = ".1f"
610
+ else:
611
+ secondary_values = [c / n for c, n in zip(total_costs, n_trades)]
612
+ secondary_name = "Avg/Trade"
613
+ secondary_format = "$,.0f"
614
+
615
+ fig = make_subplots(specs=[[{"secondary_y": True}]])
616
+
617
+ # Bar chart for total costs
618
+ fig.add_trace(
619
+ go.Bar(
620
+ x=symbols,
621
+ y=total_costs,
622
+ name="Total Cost",
623
+ marker_color=colors[0],
624
+ hovertemplate="%{x}<br>Total: $%{y:,.0f}<extra></extra>",
625
+ ),
626
+ secondary_y=False,
627
+ )
628
+
629
+ # Line for secondary metric
630
+ fig.add_trace(
631
+ go.Scatter(
632
+ x=symbols,
633
+ y=secondary_values,
634
+ name=secondary_name,
635
+ mode="lines+markers",
636
+ line={"color": colors[1] if len(colors) > 1 else "red", "width": 2},
637
+ marker={"size": 8},
638
+ hovertemplate=f"%{{x}}<br>{secondary_name}: %{{y:{secondary_format}}}<extra></extra>",
639
+ ),
640
+ secondary_y=True,
641
+ )
642
+
643
+ # Build layout
644
+ layout_updates = {
645
+ "title": {"text": title, "font": {"size": 18}},
646
+ "height": height,
647
+ "xaxis": {"title": "Asset", "tickangle": -45},
648
+ "yaxis": {"title": "Total Cost ($)", "tickformat": "$,.0f"},
649
+ "legend": {"yanchor": "top", "y": 0.99, "xanchor": "right", "x": 0.99},
650
+ "bargap": 0.3,
651
+ }
652
+ if width:
653
+ layout_updates["width"] = width
654
+
655
+ for key, value in theme_config["layout"].items():
656
+ if key not in layout_updates:
657
+ layout_updates[key] = value
658
+
659
+ fig.update_layout(**layout_updates)
660
+
661
+ # Update secondary y-axis
662
+ if show_pct:
663
+ fig.update_yaxes(title_text="Cost (% of Gross)", tickformat=".1f%", secondary_y=True)
664
+ else:
665
+ fig.update_yaxes(title_text="Avg Cost/Trade ($)", tickformat="$,.0f", secondary_y=True)
666
+
667
+ return fig
668
+
669
+
670
def plot_cost_pie(
    commission: float,
    slippage: float,
    other_costs: dict[str, float] | None = None,
    title: str = "Cost Breakdown",
    theme: str | None = None,
    height: int = 400,
    width: int | None = None,
) -> go.Figure:
    """Create a donut chart showing the breakdown of transaction costs.

    Every cost is plotted by magnitude (``abs``), so the sign convention
    used for costs does not affect the chart. Each slice is labelled with
    its dollar amount and its share of the total, and the overall total is
    annotated in the center of the donut.

    Parameters
    ----------
    commission : float
        Commission costs
    slippage : float
        Slippage costs
    other_costs : dict[str, float], optional
        Additional cost categories, keyed by the label to display
    title : str
        Chart title
    theme : str, optional
        Theme name
    height : int
        Figure height in pixels
    width : int, optional
        Figure width in pixels

    Returns
    -------
    go.Figure
        Plotly pie figure (rendered as a donut via ``hole=0.4``)
    """
    theme_config = get_theme_config(theme)
    colors = theme_config["colorway"]

    # Build labels and values; only magnitudes are plotted since costs may be signed
    labels = ["Commission", "Slippage"]
    values = [abs(commission), abs(slippage)]

    if other_costs:
        for name, cost in other_costs.items():
            labels.append(name)
            values.append(abs(cost))

    # Slice text: dollar amount plus share of the total. A zero total (no costs
    # at all) would otherwise raise ZeroDivisionError, so shares fall back to 0.
    total = sum(values)
    shares = [v / total * 100 if total else 0.0 for v in values]
    text_info = [f"${v:,.0f}<br>({pct:.1f}%)" for v, pct in zip(values, shares)]

    fig = go.Figure(
        go.Pie(
            labels=labels,
            values=values,
            text=text_info,
            textinfo="text",
            hovertemplate="%{label}<br>$%{value:,.0f}<br>%{percent}<extra></extra>",
            marker={"colors": colors[: len(labels)]},
            hole=0.4,  # Donut chart
        )
    )

    # Add total in center of the donut hole
    fig.add_annotation(
        text=f"Total<br>${total:,.0f}",
        x=0.5,
        y=0.5,
        font={"size": 16},
        showarrow=False,
    )

    # Build layout; explicit settings here take precedence over theme defaults
    layout_updates = {
        "title": {"text": title, "font": {"size": 18}},
        "height": height,
        "showlegend": True,
        "legend": {
            "orientation": "h",
            "yanchor": "bottom",
            "y": -0.1,
            "xanchor": "center",
            "x": 0.5,
        },
    }
    if width:
        layout_updates["width"] = width

    for key, value in theme_config["layout"].items():
        if key not in layout_updates:
            layout_updates[key] = value

    fig.update_layout(**layout_updates)

    return fig