ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,762 @@
|
|
|
1
|
+
"""Cost attribution visualizations for backtest analysis.
|
|
2
|
+
|
|
3
|
+
Provides interactive Plotly visualizations for understanding
|
|
4
|
+
the impact of transaction costs on strategy performance.
|
|
5
|
+
|
|
6
|
+
Key visualizations:
|
|
7
|
+
- Cost waterfall (Gross → Commission → Slippage → Net)
|
|
8
|
+
- Cost sensitivity analysis (Sharpe degradation as costs increase)
|
|
9
|
+
- Cost over time (rolling cost impact)
|
|
10
|
+
- Cost by asset (identify high-cost positions)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Literal
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import plotly.graph_objects as go
|
|
19
|
+
from plotly.subplots import make_subplots
|
|
20
|
+
|
|
21
|
+
from ml4t.diagnostic.visualization.core import get_theme_config
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
import polars as pl
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def plot_cost_waterfall(
    gross_pnl: float,
    commission: float,
    slippage: float,
    net_pnl: float | None = None,
    other_costs: dict[str, float] | None = None,
    title: str = "Cost Attribution Waterfall",
    show_percentages: bool = True,
    theme: str | None = None,
    height: int = 500,
    width: int | None = None,
) -> go.Figure:
    """Create a waterfall chart showing gross-to-net PnL decomposition.

    Visualizes how transaction costs (commission, slippage) erode
    gross trading profits into net returns.

    Cost inputs are treated as magnitudes: the absolute value of each cost
    is drawn as a negative step and, when ``net_pnl`` is not supplied, is
    subtracted from ``gross_pnl``. Passing costs as positive or negative
    numbers therefore yields the same chart and the same derived net PnL.

    Parameters
    ----------
    gross_pnl : float
        Gross profit/loss before costs
    commission : float
        Total commission costs (should be positive, will be shown as negative)
    slippage : float
        Total slippage costs (should be positive, will be shown as negative)
    net_pnl : float, optional
        Net PnL after all costs. If not provided, calculated from inputs.
        If provided and it does not equal gross PnL minus the itemized
        costs, an extra "Unattributed" step is inserted for the difference,
        so the final bar matches ``net_pnl`` (Plotly draws a "total" bar at
        the running sum of the preceding steps, not at its own y value).
    other_costs : dict[str, float], optional
        Additional cost categories (e.g., {"Financing": 500, "Fees": 200})
    title : str
        Chart title
    show_percentages : bool
        Whether to show cost as percentage of gross
    theme : str, optional
        Theme name (default, dark, print, presentation)
    height : int
        Figure height in pixels
    width : int, optional
        Figure width in pixels

    Returns
    -------
    go.Figure
        Plotly figure with waterfall chart

    Examples
    --------
    >>> fig = plot_cost_waterfall(
    ...     gross_pnl=100000,
    ...     commission=2500,
    ...     slippage=1500,
    ... )
    >>> fig.show()
    """
    theme_config = get_theme_config(theme)

    # Collect every cost category as (label, magnitude) so the drawn bars and
    # the derived net PnL are computed from exactly the same numbers.
    cost_items = [("Commission", abs(commission)), ("Slippage", abs(slippage))]
    if other_costs:
        cost_items.extend((name, abs(cost)) for name, cost in other_costs.items())

    labels = ["Gross PnL"] + [name for name, _ in cost_items]
    values = [gross_pnl] + [-magnitude for _, magnitude in cost_items]
    measures = ["absolute"] + ["relative"] * len(cost_items)

    implied_net = gross_pnl - sum(magnitude for _, magnitude in cost_items)
    if net_pnl is None:
        net_pnl = implied_net
    else:
        # Plotly ignores the y value of a "total" bar, so an unreconciled
        # net_pnl would only change the label. Make the gap an explicit step.
        residual = net_pnl - implied_net
        if not np.isclose(net_pnl, implied_net, rtol=1e-9, atol=1e-9):
            labels.append("Unattributed")
            values.append(residual)
            measures.append("relative")

    labels.append("Net PnL")
    values.append(net_pnl)
    measures.append("total")

    # Create hover text with percentages
    if show_percentages and gross_pnl != 0:
        text = [f"${gross_pnl:,.0f}"]
        for val in values[1:-1]:
            pct = abs(val) / abs(gross_pnl) * 100
            text.append(f"${val:,.0f} ({pct:.1f}%)")
        text.append(f"${net_pnl:,.0f}")
    else:
        text = [f"${v:,.0f}" for v in values]

    # Determine colors
    colors = theme_config["colorway"]
    increasing_color = colors[0]  # Usually green/blue
    decreasing_color = colors[1] if len(colors) > 1 else "#EF553B"  # Red for costs
    totals_color = colors[2] if len(colors) > 2 else "#636EFA"  # Blue for totals

    fig = go.Figure(
        go.Waterfall(
            name="Cost Attribution",
            orientation="v",
            x=labels,
            y=values,
            measure=measures,
            text=text,
            textposition="outside",
            increasing={"marker": {"color": increasing_color}},
            decreasing={"marker": {"color": decreasing_color}},
            totals={"marker": {"color": totals_color}},
            connector={"line": {"color": "rgba(128, 128, 128, 0.5)", "width": 2}},
        )
    )

    # Build layout
    layout_updates = {
        "title": {"text": title, "font": {"size": 18}},
        "height": height,
        "yaxis": {"title": "PnL ($)", "tickformat": "$,.0f"},
        "showlegend": False,
    }
    if width:
        layout_updates["width"] = width

    # Merge theme layout without overwriting explicit settings
    for key, value in theme_config["layout"].items():
        if key not in layout_updates:
            layout_updates[key] = value

    fig.update_layout(**layout_updates)

    return fig
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def plot_cost_sensitivity(
    returns: pl.Series | np.ndarray,
    base_costs_bps: float = 10.0,
    cost_multipliers: list[float] | None = None,
    trades_per_year: int = 252,
    risk_free_rate: float = 0.0,
    title: str = "Cost Sensitivity Analysis",
    show_breakeven: bool = True,
    theme: str | None = None,
    height: int = 500,
    width: int | None = None,
) -> go.Figure:
    """Analyze how Sharpe ratio degrades as transaction costs increase.

    Shows the sensitivity of risk-adjusted returns to transaction costs,
    helping identify the breakeven point where strategy becomes unprofitable.

    The cost drag is modeled as a constant daily deduction from the mean
    gross return: ``cost_bps / 10000 * trades_per_year / 252``. Volatility is
    taken from the gross returns and is unchanged by costs. Metrics are
    annualized assuming 252 periods per year.

    Parameters
    ----------
    returns : pl.Series or np.ndarray
        Gross daily returns (before costs). Non-finite values (NaN, inf,
        polars nulls) are dropped before computing statistics.
    base_costs_bps : float
        Base transaction cost in basis points (e.g., 10 = 0.1%)
    cost_multipliers : list[float], optional
        Multipliers to test (default: [0, 0.5, 1, 1.5, 2, 3, 5]). They are
        evaluated in ascending order, so unsorted input still yields a
        monotone x-axis and a meaningful breakeven search.
    trades_per_year : int
        Estimated number of trades per year for cost impact
    risk_free_rate : float
        Annual risk-free rate for Sharpe calculation
    title : str
        Chart title
    show_breakeven : bool
        Whether to annotate the breakeven cost level (first cost level at
        which the Sharpe ratio crosses from positive to non-positive,
        linearly interpolated between tested levels)
    theme : str, optional
        Theme name
    height : int
        Figure height in pixels
    width : int, optional
        Figure width in pixels

    Returns
    -------
    go.Figure
        Plotly figure with cost sensitivity chart

    Raises
    ------
    ValueError
        If fewer than two finite return observations are available, since
        a sample standard deviation (ddof=1) cannot be computed.
    """
    import polars as pl

    theme_config = get_theme_config(theme)

    # Convert to a float numpy array and drop non-finite entries; a single
    # NaN would otherwise turn every mean/std (and hence the chart) into NaN.
    if isinstance(returns, pl.Series):
        returns_arr = returns.to_numpy()
    else:
        returns_arr = np.asarray(returns)
    returns_arr = np.asarray(returns_arr, dtype=float)
    returns_arr = returns_arr[np.isfinite(returns_arr)]
    if returns_arr.size < 2:
        raise ValueError(
            "returns must contain at least two finite observations, "
            f"got {returns_arr.size}"
        )

    # Default multipliers; sort so cost levels (and the breakeven scan) ascend.
    if cost_multipliers is None:
        cost_multipliers = [0, 0.5, 1, 1.5, 2, 3, 5]
    cost_multipliers = sorted(cost_multipliers)

    gross_mean = np.mean(returns_arr)
    gross_std = np.std(returns_arr, ddof=1)
    daily_rf = risk_free_rate / 252
    trades_per_day = trades_per_year / 252

    # Calculate metrics at each cost level
    cost_levels = []
    sharpe_values = []
    cagr_values = []
    for mult in cost_multipliers:
        cost_bps = base_costs_bps * mult
        cost_per_trade = cost_bps / 10000  # bps -> decimal
        net_mean = gross_mean - cost_per_trade * trades_per_day

        if gross_std > 0:
            sharpe = (net_mean - daily_rf) / gross_std * np.sqrt(252)
        else:
            sharpe = 0.0

        # Approximate CAGR by compounding the mean daily net return
        cagr = (1 + net_mean) ** 252 - 1

        cost_levels.append(cost_bps)
        sharpe_values.append(sharpe)
        cagr_values.append(cagr * 100)  # As percentage

    colors = theme_config["colorway"]

    # Create subplot with Sharpe and CAGR
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Sharpe Ratio vs Costs", "CAGR vs Costs"),
        horizontal_spacing=0.12,
    )

    # Sharpe trace
    fig.add_trace(
        go.Scatter(
            x=cost_levels,
            y=sharpe_values,
            mode="lines+markers",
            name="Sharpe Ratio",
            line={"color": colors[0], "width": 3},
            marker={"size": 10},
            hovertemplate="Cost: %{x:.1f} bps<br>Sharpe: %{y:.2f}<extra></extra>",
        ),
        row=1,
        col=1,
    )
    fig.add_hline(y=0, line_dash="dash", line_color="gray", row=1, col=1)

    # CAGR trace
    fig.add_trace(
        go.Scatter(
            x=cost_levels,
            y=cagr_values,
            mode="lines+markers",
            name="CAGR (%)",
            line={"color": colors[1] if len(colors) > 1 else colors[0], "width": 3},
            marker={"size": 10},
            hovertemplate="Cost: %{x:.1f} bps<br>CAGR: %{y:.1f}%<extra></extra>",
        ),
        row=1,
        col=2,
    )
    fig.add_hline(y=0, line_dash="dash", line_color="gray", row=1, col=2)

    # Find breakeven point (first positive -> non-positive Sharpe crossing)
    if show_breakeven:
        points = list(zip(cost_levels, sharpe_values))
        for (cost_lo, sharpe_lo), (cost_hi, sharpe_hi) in zip(points, points[1:]):
            if sharpe_lo > 0 and sharpe_hi <= 0:
                # Linear interpolation; sharpe_hi - sharpe_lo < 0, never zero here
                breakeven = cost_lo + (0 - sharpe_lo) / (sharpe_hi - sharpe_lo) * (
                    cost_hi - cost_lo
                )
                fig.add_vline(
                    x=breakeven,
                    line_dash="dot",
                    line_color="red",
                    annotation_text=f"Breakeven: {breakeven:.1f} bps",
                    annotation_position="top",
                    row=1,
                    col=1,
                )
                break

    # Mark current cost level
    if base_costs_bps in cost_levels:
        idx = cost_levels.index(base_costs_bps)
        fig.add_annotation(
            x=base_costs_bps,
            y=sharpe_values[idx],
            text="Current",
            showarrow=True,
            arrowhead=2,
            row=1,
            col=1,
        )

    # Build layout
    layout_updates = {
        "title": {"text": title, "font": {"size": 18}},
        "height": height,
        "showlegend": False,
        "xaxis": {"title": "Transaction Cost (bps)"},
        "xaxis2": {"title": "Transaction Cost (bps)"},
        "yaxis": {"title": "Sharpe Ratio"},
        "yaxis2": {"title": "CAGR (%)"},
    }
    if width:
        layout_updates["width"] = width

    # Merge theme layout without overwriting explicit settings
    for key, value in theme_config["layout"].items():
        if key not in layout_updates:
            layout_updates[key] = value

    fig.update_layout(**layout_updates)

    return fig
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def plot_cost_over_time(
    dates: pl.Series | np.ndarray,
    gross_returns: pl.Series | np.ndarray,
    net_returns: pl.Series | np.ndarray,
    rolling_window: int = 63,
    title: str = "Cost Impact Over Time",
    theme: str | None = None,
    height: int = 450,
    width: int | None = None,
) -> go.Figure:
    """Visualize how transaction costs impact returns over time.

    Shows the difference between gross and net returns on a rolling basis,
    helping identify periods of high cost impact. Each series is a trailing
    mean over ``rolling_window`` observations, annualized as ``mean * 252``
    and shown in percent. The first ``rolling_window - 1`` points are NaN
    (not drawn).

    Parameters
    ----------
    dates : pl.Series or np.ndarray
        Date index
    gross_returns : pl.Series or np.ndarray
        Gross daily returns (before costs)
    net_returns : pl.Series or np.ndarray
        Net daily returns (after costs)
    rolling_window : int
        Rolling window for smoothing (default: 63 = ~3 months)
    title : str
        Chart title
    theme : str, optional
        Theme name
    height : int
        Figure height in pixels
    width : int, optional
        Figure width in pixels

    Returns
    -------
    go.Figure
        Plotly figure with rolling cost impact

    Raises
    ------
    ValueError
        If ``rolling_window`` is less than 1, or if ``dates``,
        ``gross_returns`` and ``net_returns`` differ in length.
    """
    import polars as pl

    if rolling_window < 1:
        raise ValueError(f"rolling_window must be >= 1, got {rolling_window}")

    theme_config = get_theme_config(theme)
    colors = theme_config["colorway"]

    # Dates become a plain list for Plotly; returns become float arrays.
    if isinstance(dates, pl.Series):
        dates_arr = dates.to_list()
    else:
        dates_arr = list(dates)

    if isinstance(gross_returns, pl.Series):
        gross_arr = gross_returns.to_numpy()
    else:
        gross_arr = np.asarray(gross_returns)
    gross_arr = np.asarray(gross_arr, dtype=float)

    if isinstance(net_returns, pl.Series):
        net_arr = net_returns.to_numpy()
    else:
        net_arr = np.asarray(net_returns)
    net_arr = np.asarray(net_arr, dtype=float)

    # Misaligned inputs would otherwise broadcast, raise an opaque error, or
    # be silently truncated by Plotly.
    if not (len(dates_arr) == len(gross_arr) == len(net_arr)):
        raise ValueError(
            "dates, gross_returns and net_returns must have the same length, got "
            f"{len(dates_arr)}, {len(gross_arr)} and {len(net_arr)}"
        )

    # Calculate cost drag
    cost_drag = gross_arr - net_arr

    def rolling_mean(arr: np.ndarray, window: int) -> np.ndarray:
        """Trailing mean over `window` points; leading window-1 entries are NaN.

        A NaN in `arr` only affects the windows that contain it.
        """
        result = np.full(len(arr), np.nan)
        if len(arr) >= window:
            windows = np.lib.stride_tricks.sliding_window_view(arr, window)
            result[window - 1 :] = windows.mean(axis=1)
        return result

    rolling_gross = rolling_mean(gross_arr, rolling_window) * 252 * 100
    rolling_net = rolling_mean(net_arr, rolling_window) * 252 * 100
    rolling_cost = rolling_mean(cost_drag, rolling_window) * 252 * 100

    fig = go.Figure()

    # Gross returns
    fig.add_trace(
        go.Scatter(
            x=dates_arr,
            y=rolling_gross,
            name="Gross Returns (ann.)",
            mode="lines",
            line={"color": colors[0], "width": 2},
            hovertemplate="%{x}<br>Gross: %{y:.1f}%<extra></extra>",
        )
    )

    # Net returns
    fig.add_trace(
        go.Scatter(
            x=dates_arr,
            y=rolling_net,
            name="Net Returns (ann.)",
            mode="lines",
            line={"color": colors[1] if len(colors) > 1 else colors[0], "width": 2},
            hovertemplate="%{x}<br>Net: %{y:.1f}%<extra></extra>",
        )
    )

    # Cost drag (as filled area)
    fig.add_trace(
        go.Scatter(
            x=dates_arr,
            y=rolling_cost,
            name="Cost Drag (ann.)",
            mode="lines",
            fill="tozeroy",
            line={"color": "rgba(239, 85, 59, 0.7)", "width": 1},
            fillcolor="rgba(239, 85, 59, 0.3)",
            hovertemplate="%{x}<br>Cost Drag: %{y:.1f}%<extra></extra>",
        )
    )

    # Build layout
    layout_updates = {
        "title": {"text": title, "font": {"size": 18}},
        "height": height,
        "xaxis": {"title": "Date"},
        "yaxis": {"title": "Annualized Return (%)"},
        "legend": {"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01},
        "hovermode": "x unified",
    }
    if width:
        layout_updates["width"] = width

    # Merge theme layout without overwriting explicit settings
    for key, value in theme_config["layout"].items():
        if key not in layout_updates:
            layout_updates[key] = value

    fig.update_layout(**layout_updates)

    return fig
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def plot_cost_by_asset(
    trades: pl.DataFrame,
    top_n: int = 10,
    cost_column: str = "cost",
    symbol_column: str = "symbol",
    sort_by: Literal["total", "per_trade", "percentage"] = "total",
    title: str = "Transaction Costs by Asset",
    theme: str | None = None,
    height: int = 450,
    width: int | None = None,
) -> go.Figure:
    """Show transaction cost breakdown by asset.

    Helps identify which assets incur the highest costs and may need
    different position sizing or execution strategies. Costs are aggregated
    per symbol. The top ``top_n`` symbols (ranked by ``sort_by``) are drawn
    as bars of total cost. A secondary-axis line shows either cost as a
    percentage of gross PnL (when ``sort_by="percentage"`` and a
    ``gross_pnl`` column exists) or the average cost per trade.

    Parameters
    ----------
    trades : pl.DataFrame
        Trade records with symbol and cost columns. If ``cost_column`` is
        absent but both ``gross_pnl`` and ``net_pnl`` are present, the cost
        is derived as ``gross_pnl - net_pnl``.
    top_n : int
        Number of top assets to show
    cost_column : str
        Name of the cost column
    symbol_column : str
        Name of the symbol column
    sort_by : {"total", "per_trade", "percentage"}
        How to rank assets:
        - "total": Total cost in dollars
        - "per_trade": Average cost per trade
        - "percentage": Cost as % of gross PnL. Requires a ``gross_pnl``
          column; without one, ranking falls back to "total".
        Ties are broken by symbol so the selected top-N set is deterministic.
    title : str
        Chart title
    theme : str, optional
        Theme name
    height : int
        Figure height in pixels
    width : int, optional
        Figure width in pixels

    Returns
    -------
    go.Figure
        Plotly figure with cost breakdown by asset

    Raises
    ------
    ValueError
        If ``sort_by`` is not one of the supported options, if the cost
        column is missing and cannot be derived, or if the symbol column
        is missing.
    """
    import polars as pl

    # Maps each ranking option to the aggregated column it sorts on.
    sort_columns = {"total": "total_cost", "per_trade": "avg_cost", "percentage": "cost_pct"}
    if sort_by not in sort_columns:
        raise ValueError(
            f"Invalid sort_by '{sort_by}'; expected one of {sorted(sort_columns)}"
        )

    theme_config = get_theme_config(theme)
    colors = theme_config["colorway"]

    # Check if required columns exist
    if cost_column not in trades.columns:
        # Try to calculate cost from pnl columns. Alias to the requested
        # name so an unrelated existing "cost" column is never overwritten.
        if "gross_pnl" in trades.columns and "net_pnl" in trades.columns:
            trades = trades.with_columns(
                (pl.col("gross_pnl") - pl.col("net_pnl")).alias(cost_column)
            )
        else:
            raise ValueError(f"Cost column '{cost_column}' not found and cannot be calculated")

    if symbol_column not in trades.columns:
        raise ValueError(f"Symbol column '{symbol_column}' not found")

    # Aggregate by symbol. mean() ignores nulls and yields null (not an
    # error) for a symbol whose costs are all null.
    agg_cols = [
        pl.col(cost_column).sum().alias("total_cost"),
        pl.col(cost_column).mean().alias("avg_cost"),
    ]
    if "gross_pnl" in trades.columns:
        agg_cols.append(pl.col("gross_pnl").sum().alias("total_gross"))

    cost_by_symbol = trades.group_by(symbol_column).agg(agg_cols)

    # Cost as % of |gross PnL|; null when gross is zero (avoids inf/NaN,
    # which would otherwise sort to the top of a descending ranking).
    if "total_gross" in cost_by_symbol.columns:
        cost_by_symbol = cost_by_symbol.with_columns(
            pl.when(pl.col("total_gross") != 0)
            .then(pl.col("total_cost") / pl.col("total_gross").abs() * 100)
            .otherwise(None)
            .alias("cost_pct")
        )

    # Always sort: group_by output order is arbitrary, so an unsorted frame
    # would make head(top_n) pick an arbitrary subset of assets.
    sort_column = sort_columns[sort_by]
    if sort_column not in cost_by_symbol.columns:
        sort_column = "total_cost"
    cost_by_symbol = cost_by_symbol.sort(
        [sort_column, symbol_column],
        descending=[True, False],
        nulls_last=True,
    )

    # Take top N
    top_assets = cost_by_symbol.head(top_n)

    symbols = top_assets[symbol_column].to_list()
    total_costs = top_assets["total_cost"].to_list()

    # Determine what to show on secondary axis
    show_pct = sort_by == "percentage" and "cost_pct" in top_assets.columns

    if show_pct:
        secondary_values = top_assets["cost_pct"].to_list()
        secondary_name = "Cost %"
        secondary_hover = "%{y:.1f}%"
    else:
        secondary_values = top_assets["avg_cost"].to_list()
        secondary_name = "Avg/Trade"
        secondary_hover = "%{y:$,.0f}"

    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # Bar chart for total costs
    fig.add_trace(
        go.Bar(
            x=symbols,
            y=total_costs,
            name="Total Cost",
            marker_color=colors[0],
            hovertemplate="%{x}<br>Total: $%{y:,.0f}<extra></extra>",
        ),
        secondary_y=False,
    )

    # Line for secondary metric
    fig.add_trace(
        go.Scatter(
            x=symbols,
            y=secondary_values,
            name=secondary_name,
            mode="lines+markers",
            line={"color": colors[1] if len(colors) > 1 else "red", "width": 2},
            marker={"size": 8},
            hovertemplate=f"%{{x}}<br>{secondary_name}: {secondary_hover}<extra></extra>",
        ),
        secondary_y=True,
    )

    # Build layout; explicit settings take precedence over theme defaults.
    layout_updates = {
        "title": {"text": title, "font": {"size": 18}},
        "height": height,
        "xaxis": {"title": "Asset", "tickangle": -45},
        "yaxis": {"title": "Total Cost ($)", "tickformat": "$,.0f"},
        "legend": {"yanchor": "top", "y": 0.99, "xanchor": "right", "x": 0.99},
        "bargap": 0.3,
    }
    if width:
        layout_updates["width"] = width

    for key, value in theme_config["layout"].items():
        if key not in layout_updates:
            layout_updates[key] = value

    fig.update_layout(**layout_updates)

    # Update secondary y-axis. Percent values are already scaled by 100, so
    # use a plain number format plus a "%" suffix (".1f%" is not a valid
    # d3-format specifier).
    if show_pct:
        fig.update_yaxes(
            title_text="Cost (% of Gross)", tickformat=".1f", ticksuffix="%", secondary_y=True
        )
    else:
        fig.update_yaxes(title_text="Avg Cost/Trade ($)", tickformat="$,.0f", secondary_y=True)

    return fig
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def plot_cost_pie(
    commission: float,
    slippage: float,
    other_costs: dict[str, float] | None = None,
    title: str = "Cost Breakdown",
    theme: str | None = None,
    height: int = 400,
    width: int | None = None,
) -> go.Figure:
    """Create a donut chart showing the breakdown of transaction costs.

    All cost amounts are plotted by absolute value. Each slice is labelled
    with its dollar amount and share of the total. The total is annotated in
    the centre of the donut.

    Parameters
    ----------
    commission : float
        Commission costs
    slippage : float
        Slippage costs
    other_costs : dict[str, float], optional
        Additional cost categories, keyed by display label
    title : str
        Chart title
    theme : str, optional
        Theme name
    height : int
        Figure height in pixels
    width : int, optional
        Figure width in pixels

    Returns
    -------
    go.Figure
        Plotly pie chart figure
    """
    theme_config = get_theme_config(theme)
    colors = theme_config["colorway"]

    # Build labels and values (magnitudes only; sign conventions vary).
    labels = ["Commission", "Slippage"]
    values = [abs(commission), abs(slippage)]

    if other_costs:
        for name, cost in other_costs.items():
            labels.append(name)
            values.append(abs(cost))

    # Calculate percentages for slice text. An all-zero breakdown has no
    # meaningful shares, so report 0% rather than dividing by zero.
    total = sum(values)
    shares = [v / total * 100 if total else 0.0 for v in values]
    text_info = [f"${v:,.0f}<br>({pct:.1f}%)" for v, pct in zip(values, shares)]

    # Cycle through the theme colorway so every slice gets a theme color
    # even when there are more categories than colors.
    marker = (
        {"colors": [colors[i % len(colors)] for i in range(len(labels))]} if colors else {}
    )

    fig = go.Figure(
        go.Pie(
            labels=labels,
            values=values,
            text=text_info,
            textinfo="text",
            hovertemplate="%{label}<br>$%{value:,.0f}<br>%{percent}<extra></extra>",
            marker=marker,
            hole=0.4,  # Donut chart
        )
    )

    # Add total in center
    fig.add_annotation(
        text=f"Total<br>${total:,.0f}",
        x=0.5,
        y=0.5,
        font={"size": 16},
        showarrow=False,
    )

    # Build layout; explicit settings take precedence over theme defaults.
    layout_updates = {
        "title": {"text": title, "font": {"size": 18}},
        "height": height,
        "showlegend": True,
        "legend": {
            "orientation": "h",
            "yanchor": "bottom",
            "y": -0.1,
            "xanchor": "center",
            "x": 0.5,
        },
    }
    if width:
        layout_updates["width"] = width

    for key, value in theme_config["layout"].items():
        if key not in layout_updates:
            layout_updates[key] = value

    fig.update_layout(**layout_updates)

    return fig
|