ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,603 @@
|
|
|
1
|
+
"""Multi-Signal Analysis Visualization Plots.
|
|
2
|
+
|
|
3
|
+
This module provides interactive Plotly visualizations for multi-signal analysis:
|
|
4
|
+
- plot_ic_ridge: IC density ridge plot showing distribution per signal
|
|
5
|
+
- plot_signal_ranking_bar: Horizontal bar chart of signals by metric
|
|
6
|
+
- plot_signal_correlation_heatmap: Signal correlation heatmap with clustering
|
|
7
|
+
- plot_pareto_frontier: Scatter plot with Pareto frontier highlighted
|
|
8
|
+
|
|
9
|
+
All plots follow the Focus+Context pattern for analyzing 50-200 signals:
|
|
10
|
+
- Focus: Selected/significant signals highlighted
|
|
11
|
+
- Context: All signals shown in background for comparison
|
|
12
|
+
|
|
13
|
+
References
|
|
14
|
+
----------
|
|
15
|
+
Tufte, E. (1983). "The Visual Display of Quantitative Information"
|
|
16
|
+
Few, S. (2012). "Show Me the Numbers"
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from typing import TYPE_CHECKING
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
import plotly.graph_objects as go
|
|
25
|
+
from scipy.cluster.hierarchy import dendrogram, linkage
|
|
26
|
+
from scipy.spatial.distance import squareform
|
|
27
|
+
|
|
28
|
+
from ml4t.diagnostic.visualization.core import (
|
|
29
|
+
create_base_figure,
|
|
30
|
+
get_colorscale,
|
|
31
|
+
get_theme_config,
|
|
32
|
+
validate_theme,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
import polars as pl
|
|
37
|
+
|
|
38
|
+
from ml4t.diagnostic.results.multi_signal_results import MultiSignalSummary
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# =============================================================================
|
|
42
|
+
# IC Ridge Plot
|
|
43
|
+
# =============================================================================
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def plot_ic_ridge(
    summary: MultiSignalSummary,
    max_signals: int = 50,
    sort_by: str = "ic_mean",
    show_significance: bool = True,
    theme: str | None = None,
    width: int | None = None,
    height: int | None = None,
) -> go.Figure:
    """IC range ("ridge") plot showing the IC distribution per signal.

    Draws one horizontal bar per signal spanning the IC 5th to 95th percentile
    (``ic_p5`` / ``ic_p95`` columns), with a diamond marker at the IC mean.
    If the percentile columns are missing, the bar spans
    ``ic_mean +/- 1.96 * ic_std`` instead. If ``ic_std`` is missing too, only
    the mean marker is visible.

    Color indicates FDR significance: the theme's primary color marks
    significant signals, and gray marks signals that are not significant.

    Parameters
    ----------
    summary : MultiSignalSummary
        Summary metrics from MultiSignalAnalysis.compute_summary()
    max_signals : int, default 50
        Maximum number of signals to display
    sort_by : str, default "ic_mean"
        Metric to sort signals by (descending). Options: "ic_mean", "ic_ir",
        "ic_t_stat". Signals with a missing value for this metric are ranked last.
    show_significance : bool, default True
        Color by FDR significance (requires a ``fdr_significant`` column)
    theme : str | None
        Plot theme (default, dark, print, presentation)
    width : int | None
        Figure width in pixels (800 if None)
    height : int | None
        Figure height in pixels (auto-scaled by n_signals if None)

    Returns
    -------
    go.Figure
        Interactive Plotly figure

    Raises
    ------
    ValueError
        If ``sort_by`` is not a column of the summary DataFrame.

    Examples
    --------
    >>> summary = analyzer.compute_summary()
    >>> fig = plot_ic_ridge(summary, max_signals=30, sort_by="ic_ir")
    >>> fig.show()
    """
    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)

    df = summary.get_dataframe()

    if sort_by not in df.columns:
        available = [c for c in df.columns if "ic" in c.lower()]
        raise ValueError(f"Sort metric '{sort_by}' not found. Available: {available}")

    # Best-first, limited to max_signals. Polars places nulls first by default
    # (even for descending sorts); push them to the end so signals with a
    # missing metric are not presented as the top performers.
    df = df.sort(sort_by, descending=True, nulls_last=True).head(max_signals)
    n_signals = len(df)

    signal_names = df["signal_name"].to_list()
    ic_means = df["ic_mean"].to_list() if "ic_mean" in df.columns else [0] * n_signals

    # Bar extent: prefer empirical percentiles, else a normal-approximation
    # band from the std, else a zero-width bar at the mean.
    if "ic_p5" in df.columns and "ic_p95" in df.columns:
        ic_lower = df["ic_p5"].to_list()
        ic_upper = df["ic_p95"].to_list()
    elif "ic_std" in df.columns:
        ic_stds = df["ic_std"].to_list()
        # Missing mean or std yields a missing bound (None) instead of a TypeError
        ic_lower = [
            None if m is None or s is None else m - 1.96 * s
            for m, s in zip(ic_means, ic_stds)
        ]
        ic_upper = [
            None if m is None or s is None else m + 1.96 * s
            for m, s in zip(ic_means, ic_stds)
        ]
    else:
        ic_lower = ic_means
        ic_upper = ic_means

    if show_significance and "fdr_significant" in df.columns:
        fdr_significant = df["fdr_significant"].to_list()
    else:
        fdr_significant = [False] * n_signals

    # Significant -> theme primary color; not significant (or unknown) -> gray
    colors = [theme_config["colorway"][0] if sig else "#888888" for sig in fdr_significant]

    if height is None:
        height = max(400, min(1200, n_signals * 25 + 100))

    fig = create_base_figure(
        title=f"IC Distribution by Signal (Top {n_signals} by {sort_by})",
        xaxis_title="Information Coefficient",
        yaxis_title="",
        width=width or 800,
        height=height,
        theme=theme,
    )

    def _fmt(value: float | None) -> str:
        """Format an IC value for hover text, tolerating missing values."""
        return "n/a" if value is None else f"{value:.4f}"

    for name, lower, upper, mean, color in zip(
        signal_names, ic_lower, ic_upper, ic_means, colors
    ):
        # Range bar (percentile or std band)
        fig.add_trace(
            go.Scatter(
                x=[lower, upper],
                y=[name, name],
                mode="lines",
                line={"color": color, "width": 4},
                showlegend=False,
                hoverinfo="skip",
            )
        )

        # Mean marker carries the hover information for the row
        fig.add_trace(
            go.Scatter(
                x=[mean],
                y=[name],
                mode="markers",
                marker={"size": 10, "color": color, "symbol": "diamond"},
                name=name,
                showlegend=False,
                hovertemplate=(
                    f"<b>{name}</b><br>"
                    f"IC Mean: {_fmt(mean)}<br>"
                    f"IC Range: [{_fmt(lower)}, {_fmt(upper)}]"
                    "<extra></extra>"
                ),
            )
        )

    # Reference line at zero IC
    fig.add_vline(x=0, line_dash="dash", line_color="gray", opacity=0.5)

    # Reverse category order so the highest-ranked signal appears at the top
    fig.update_layout(
        yaxis={"categoryorder": "array", "categoryarray": signal_names[::-1]},
        showlegend=False,
        margin={"l": 200, "r": 50, "t": 60, "b": 50},
    )

    return fig
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
# =============================================================================
|
|
189
|
+
# Signal Ranking Bar Chart
|
|
190
|
+
# =============================================================================
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def plot_signal_ranking_bar(
    summary: MultiSignalSummary,
    metric: str = "ic_ir",
    top_n: int = 20,
    color_by: str | None = "fdr_significant",
    theme: str | None = None,
    width: int | None = None,
    height: int | None = None,
) -> go.Figure:
    """Horizontal bar chart of top signals by metric.

    Parameters
    ----------
    summary : MultiSignalSummary
        Summary metrics from MultiSignalAnalysis.compute_summary()
    metric : str, default "ic_ir"
        Metric to rank by (descending). Options: "ic_ir", "ic_mean",
        "ic_t_stat". Signals with a missing value are ranked last.
    top_n : int, default 20
        Maximum number of top signals to display
    color_by : str | None, default "fdr_significant"
        Boolean column used to color bars: "fdr_significant",
        "fwer_significant", or None for a single color. A column name that
        is not in the summary also falls back to a single color.
    theme : str | None
        Plot theme
    width : int | None
        Figure width in pixels (700 if None)
    height : int | None
        Figure height in pixels (auto-scaled by the number of bars if None)

    Returns
    -------
    go.Figure
        Interactive Plotly figure

    Raises
    ------
    ValueError
        If ``metric`` is not a column of the summary DataFrame.

    Examples
    --------
    >>> fig = plot_signal_ranking_bar(summary, metric="ic_ir", top_n=15)
    >>> fig.show()
    """
    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)

    df = summary.get_dataframe()

    if metric not in df.columns:
        raise ValueError(f"Metric '{metric}' not found. Available: {df.columns}")

    # Polars sorts nulls first by default; keep missing metrics out of the top
    df = df.sort(metric, descending=True, nulls_last=True).head(top_n)
    # May be fewer than top_n when the summary has fewer signals
    n_shown = len(df)

    signal_names = df["signal_name"].to_list()
    values = df[metric].to_list()

    primary_color = theme_config["colorway"][0]
    if color_by and color_by in df.columns:
        colors = [primary_color if sig else "#CCCCCC" for sig in df[color_by].to_list()]
    else:
        colors = [primary_color] * n_shown

    # Size to the bars actually drawn, not the requested maximum
    if height is None:
        height = max(400, min(800, n_shown * 30 + 100))

    metric_label = metric.upper().replace("_", " ")
    fig = create_base_figure(
        title=f"Top {n_shown} Signals by {metric_label}",
        xaxis_title=metric_label,
        yaxis_title="",
        width=width or 700,
        height=height,
        theme=theme,
    )

    fig.add_trace(
        go.Bar(
            x=values,
            y=signal_names,
            orientation="h",
            marker={"color": colors},
            # Missing values get an empty label instead of raising TypeError
            text=["" if v is None else f"{v:.3f}" for v in values],
            textposition="outside",
            hovertemplate="<b>%{y}</b><br>%{x:.4f}<extra></extra>",
        )
    )

    # Reverse category order so the best signal is at the top
    fig.update_layout(
        yaxis={"categoryorder": "array", "categoryarray": signal_names[::-1]},
        margin={"l": 200, "r": 80, "t": 60, "b": 50},
    )

    return fig
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# =============================================================================
|
|
289
|
+
# Signal Correlation Heatmap
|
|
290
|
+
# =============================================================================
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def plot_signal_correlation_heatmap(
    correlation_matrix: pl.DataFrame,
    cluster: bool = True,
    max_signals: int = 100,
    theme: str | None = None,
    width: int | None = None,
    height: int | None = None,
) -> go.Figure:
    """Signal correlation heatmap with optional hierarchical clustering.

    Reveals the "100 signals = 3 unique bets" pattern through correlation
    analysis. When clustering is enabled (and there are more than two
    signals), signals are reordered by the leaf order of an average-linkage
    dendrogram built on the distance ``1 - |corr|``. This means strongly
    anti-correlated signals are grouped together with correlated ones.

    Parameters
    ----------
    correlation_matrix : pl.DataFrame
        Signal correlation matrix from MultiSignalAnalysis.correlation_matrix().
        Its columns are used as the signal names, and its rows are assumed to
        be in the same order as its columns.
    cluster : bool, default True
        Apply hierarchical clustering to reorder signals
    max_signals : int, default 100
        Maximum signals to display (limits browser memory). The first N
        columns/rows are kept.
    theme : str | None
        Plot theme
    width : int | None
        Figure width in pixels (auto-scaled if None)
    height : int | None
        Figure height in pixels (auto-scaled if None)

    Returns
    -------
    go.Figure
        Interactive Plotly figure

    Examples
    --------
    >>> corr_matrix = analyzer.correlation_matrix()
    >>> fig = plot_signal_correlation_heatmap(corr_matrix, cluster=True)
    >>> fig.show()
    """
    theme = validate_theme(theme)

    signal_names = list(correlation_matrix.columns)

    # Truncate both axes to the first max_signals signals
    if len(signal_names) > max_signals:
        signal_names = signal_names[:max_signals]
        correlation_matrix = correlation_matrix.select(signal_names).head(max_signals)

    n_signals = len(signal_names)
    corr_values = correlation_matrix.to_numpy()

    # Clustering needs at least three points to produce a meaningful order
    clustered = cluster and n_signals > 2
    if clustered:
        # Treat undefined correlations as uncorrelated (distance 1)
        corr_clean = np.nan_to_num(corr_values, nan=0.0)
        distance_matrix = 1 - np.abs(corr_clean)

        # squareform expects a symmetric matrix with a zero diagonal
        distance_matrix = (distance_matrix + distance_matrix.T) / 2
        np.fill_diagonal(distance_matrix, 0)
        distance_matrix = np.clip(distance_matrix, 0, 2)

        condensed = squareform(distance_matrix, checks=False)
        linkage_matrix = linkage(condensed, method="average")
        order = dendrogram(linkage_matrix, no_plot=True)["leaves"]

        signal_names = [signal_names[i] for i in order]
        corr_values = corr_values[np.ix_(order, order)]

    if width is None:
        width = max(600, min(1000, n_signals * 10 + 200))
    if height is None:
        height = max(600, min(1000, n_signals * 10 + 200))

    title = f"Signal Correlation Matrix ({n_signals} signals)"
    # Only advertise clustering when it was actually performed
    if clustered:
        title += " - Clustered"

    # Build through create_base_figure (like the other plots in this module)
    # so the requested theme is actually applied
    fig = create_base_figure(
        title=title,
        xaxis_title="",
        yaxis_title="",
        width=width,
        height=height,
        theme=theme,
    )

    fig.add_trace(
        go.Heatmap(
            z=corr_values,
            x=signal_names,
            y=signal_names,
            colorscale=get_colorscale("rdbu"),
            zmid=0,
            zmin=-1,
            zmax=1,
            colorbar={"title": "Correlation", "tickformat": ".2f"},
            hovertemplate=("<b>%{x}</b> vs <b>%{y}</b><br>Correlation: %{z:.3f}<extra></extra>"),
        )
    )

    # Reversed y-axis puts the first signal at the top-left, matrix-style
    fig.update_layout(
        xaxis={"tickangle": 45, "side": "bottom"},
        yaxis={"autorange": "reversed"},
        margin={"l": 150, "r": 50, "t": 60, "b": 150},
    )

    return fig
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
# =============================================================================
|
|
415
|
+
# Pareto Frontier Plot
|
|
416
|
+
# =============================================================================
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def plot_pareto_frontier(
    summary: MultiSignalSummary,
    x_metric: str = "turnover_mean",
    y_metric: str = "ic_ir",
    minimize_x: bool = True,
    maximize_y: bool = True,
    highlight_pareto: bool = True,
    theme: str | None = None,
    width: int | None = None,
    height: int | None = None,
) -> go.Figure:
    """Scatter plot with Pareto frontier highlighted.

    Shows all signals as points with Pareto-optimal signals connected by line.
    Useful for identifying signals that offer best trade-offs between metrics.

    Parameters
    ----------
    summary : MultiSignalSummary
        Summary metrics from MultiSignalAnalysis.compute_summary(). Its
        dataframe must contain a ``signal_name`` column plus the two metrics.
    x_metric : str, default "turnover_mean"
        Metric for x-axis (typically something to minimize)
    y_metric : str, default "ic_ir"
        Metric for y-axis (typically something to maximize)
    minimize_x : bool, default True
        If True, lower x values are better
    maximize_y : bool, default True
        If True, higher y values are better
    highlight_pareto : bool, default True
        If True, Pareto-optimal signals are drawn in the theme's primary
        color and all others in gray, the frontier points are joined by a
        dotted line (when there are at least two of them), and a
        "Pareto optimal: k / n" annotation is added. If False, every signal
        is drawn in the primary color with no frontier decorations.
    theme : str | None
        Plot theme
    width : int | None
        Figure width in pixels (800 when not given)
    height : int | None
        Figure height in pixels (600 when not given)

    Returns
    -------
    go.Figure
        Interactive Plotly figure

    Raises
    ------
    ValueError
        If ``x_metric`` or ``y_metric`` is not a column of the summary
        dataframe.

    Examples
    --------
    >>> fig = plot_pareto_frontier(summary, x_metric="turnover_mean", y_metric="ic_ir")
    >>> fig.show()
    """
    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)
    colorway = theme_config["colorway"]
    primary_color = colorway[0]
    # A single-entry colorway would otherwise raise IndexError on [1].
    frontier_color = colorway[1] if len(colorway) > 1 else primary_color

    df = summary.get_dataframe()

    # Fail early with the available columns rather than a bare lookup error.
    for metric in (x_metric, y_metric):
        if metric not in df.columns:
            raise ValueError(f"Metric '{metric}' not found. Available: {df.columns}")

    signal_names = df["signal_name"].to_list()
    x_values = df[x_metric].to_list()
    y_values = df[y_metric].to_list()

    if highlight_pareto:
        # Pareto points get the primary color; dominated points are grayed out.
        pareto_mask = _compute_pareto_mask(x_values, y_values, minimize_x, maximize_y)
        colors = [primary_color if is_optimal else "#CCCCCC" for is_optimal in pareto_mask]
    else:
        # No highlighting requested: render every signal uniformly.
        pareto_mask = [False] * len(signal_names)
        colors = [primary_color] * len(signal_names)

    fig = create_base_figure(
        title=f"Signal Efficiency: {y_metric} vs {x_metric}",
        xaxis_title=x_metric.upper().replace("_", " "),
        yaxis_title=y_metric.upper().replace("_", " "),
        width=width or 800,
        height=height or 600,
        theme=theme,
    )

    # All signals as one marker trace; signal names feed the hover label.
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values,
            mode="markers",
            marker={
                "size": 10,
                "color": colors,
                "line": {"width": 1, "color": "white"},
            },
            text=signal_names,
            hovertemplate=(
                "<b>%{text}</b><br>"
                f"{x_metric}: %{{x:.4f}}<br>"
                f"{y_metric}: %{{y:.4f}}"
                "<extra></extra>"
            ),
            name="All Signals",
        )
    )

    if highlight_pareto:
        # Sort frontier points by x (then y) so the connecting line is monotone in x.
        frontier = sorted(
            (x, y) for x, y, is_optimal in zip(x_values, y_values, pareto_mask) if is_optimal
        )

        if len(frontier) > 1:
            fig.add_trace(
                go.Scatter(
                    x=[point[0] for point in frontier],
                    y=[point[1] for point in frontier],
                    mode="lines",
                    line={"color": frontier_color, "width": 2, "dash": "dot"},
                    name="Pareto Frontier",
                    hoverinfo="skip",
                )
            )

        # Report how many signals are Pareto-optimal out of the total.
        fig.add_annotation(
            x=0.02,
            y=0.98,
            xref="paper",
            yref="paper",
            text=f"Pareto optimal: {len(frontier)} / {len(signal_names)}",
            showarrow=False,
            font={"size": 12},
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor=primary_color,
            borderwidth=1,
        )

    return fig
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def _compute_pareto_mask(
|
|
559
|
+
x_values: list[float],
|
|
560
|
+
y_values: list[float],
|
|
561
|
+
minimize_x: bool = True,
|
|
562
|
+
maximize_y: bool = True,
|
|
563
|
+
) -> list[bool]:
|
|
564
|
+
"""Compute Pareto frontier mask.
|
|
565
|
+
|
|
566
|
+
Returns True for points on the Pareto frontier (non-dominated).
|
|
567
|
+
"""
|
|
568
|
+
n = len(x_values)
|
|
569
|
+
is_pareto = [True] * n
|
|
570
|
+
|
|
571
|
+
for i in range(n):
|
|
572
|
+
if not is_pareto[i]:
|
|
573
|
+
continue
|
|
574
|
+
|
|
575
|
+
for j in range(n):
|
|
576
|
+
if i == j:
|
|
577
|
+
continue
|
|
578
|
+
|
|
579
|
+
# Check if j dominates i
|
|
580
|
+
x_better = x_values[j] <= x_values[i] if minimize_x else x_values[j] >= x_values[i]
|
|
581
|
+
y_better = y_values[j] >= y_values[i] if maximize_y else y_values[j] <= y_values[i]
|
|
582
|
+
|
|
583
|
+
x_strictly = x_values[j] < x_values[i] if minimize_x else x_values[j] > x_values[i]
|
|
584
|
+
y_strictly = y_values[j] > y_values[i] if maximize_y else y_values[j] < y_values[i]
|
|
585
|
+
|
|
586
|
+
# j dominates i if j is at least as good in both and strictly better in one
|
|
587
|
+
if x_better and y_better and (x_strictly or y_strictly):
|
|
588
|
+
is_pareto[i] = False
|
|
589
|
+
break
|
|
590
|
+
|
|
591
|
+
return is_pareto
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
# =============================================================================
|
|
595
|
+
# Module Exports
|
|
596
|
+
# =============================================================================
|
|
597
|
+
|
|
598
|
+
# Names exported by a star-import of this module.
__all__ = [
    "plot_ic_ridge", "plot_signal_ranking_bar",
    "plot_signal_correlation_heatmap", "plot_pareto_frontier",
]
|