ml4t_diagnostic-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/visualization.py

@@ -0,0 +1,1050 @@

"""Interactive visualizations for ml4t-diagnostic evaluation results.

This module provides Plotly-based visualizations for the Three-Tier
Validation Framework, including IC heatmaps, quantile analysis,
and comprehensive evaluation dashboards.
"""

from typing import TYPE_CHECKING, Any, Union, cast

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import polars as pl
from plotly.subplots import make_subplots

from ml4t.diagnostic.backends.polars_backend import PolarsBackend

if TYPE_CHECKING:
    from numpy.typing import NDArray

# Color schemes for financial data
COLORS = {
    "positive": "#00CC88",  # Green for positive returns
    "negative": "#FF4444",  # Red for negative returns
    "neutral": "#888888",  # Gray for neutral
    "primary": "#3366CC",  # Blue for primary data
    "secondary": "#FF9900",  # Orange for secondary data
    "background": "#F8F9FA",
    "grid": "#E0E0E0",
}

# Plotly theme configuration
DEFAULT_LAYOUT = {
    "font": {"family": "Arial, sans-serif", "size": 12},
    "plot_bgcolor": COLORS["background"],
    "paper_bgcolor": "white",
    "hovermode": "closest",
    "margin": {"l": 60, "r": 30, "t": 50, "b": 60},
    "xaxis": {"gridcolor": COLORS["grid"], "zeroline": False},
    "yaxis": {"gridcolor": COLORS["grid"], "zeroline": False},
}

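# ---------------------------------------------------------------------------
# Editor's note: a minimal sketch (not part of the package) showing how the
# shared theme above is meant to be consumed -- COLORS keys feed individual
# traces, and DEFAULT_LAYOUT holds only layout keys, so it can be unpacked
# straight into `update_layout` the way every function below does.
# ---------------------------------------------------------------------------
def _demo_theme_usage() -> None:
    """Illustrative only: apply the shared theme to an ad-hoc figure."""
    fig = go.Figure(
        go.Bar(x=["a", "b"], y=[0.01, -0.02], marker_color=COLORS["primary"])
    )
    # No key collision: DEFAULT_LAYOUT does not define "title".
    fig.update_layout(title="Theme demo", **DEFAULT_LAYOUT)
    fig.show()

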
def plot_ic_heatmap(
    predictions: Union[pd.DataFrame, "NDArray[Any]"],
    returns: Union[pd.DataFrame, "NDArray[Any]"],
    horizons: list[int] | None = None,
    time_index: pd.DatetimeIndex | None = None,
    regime_column: str | None = None,
    title: str = "Information Coefficient Term Structure",
    colorscale: str = "RdBu",
    _use_optimized: bool = True,
    use_streaming: bool = True,
) -> go.Figure:
    """Create an interactive IC heatmap across multiple forward return horizons.

    This visualization shows how predictive power varies across different
    prediction horizons, helping identify the optimal holding period.

    Parameters
    ----------
    predictions : pd.DataFrame or ndarray
        Model predictions (the same predictions are used for all horizons)
    returns : pd.DataFrame or ndarray
        Forward returns for different horizons (columns = horizons)
    horizons : list[int], optional
        List of forward return horizons. If None, uses column names
    time_index : pd.DatetimeIndex, optional
        Time index for the x-axis. If None, uses an integer index
    regime_column : str, optional
        Column name for market regime filtering
    title : str, default "Information Coefficient Term Structure"
        Plot title
    colorscale : str, default "RdBu"
        Plotly colorscale name
    _use_optimized : bool, default True
        Whether to use the optimized Polars backend (always True for performance)
    use_streaming : bool, default True
        Whether to stream the IC computation for large datasets (>100k samples)

    Returns
    -------
    go.Figure
        Interactive Plotly figure

    Examples
    --------
    >>> # Simple usage
    >>> fig = plot_ic_heatmap(predictions, forward_returns)
    >>> fig.show()

    >>> # With custom horizons
    >>> fig = plot_ic_heatmap(
    ...     predictions,
    ...     returns_df,
    ...     horizons=[1, 5, 10, 20],
    ...     time_index=returns_df.index
    ... )
    """
    # Convert inputs to appropriate types
    predictions_data: pd.Series | pd.DataFrame | NDArray[Any]
    if isinstance(predictions, np.ndarray):
        predictions_data = pd.Series(predictions, name="predictions")
    else:
        predictions_data = predictions

    returns_data: pd.DataFrame | NDArray[Any]
    if isinstance(returns, np.ndarray):
        returns_data = (
            pd.DataFrame(returns, columns=cast(Any, ["returns"]))
            if returns.ndim == 1
            else pd.DataFrame(returns)
        )
    else:
        returns_data = returns

    # Determine horizons as strings for internal processing
    horizons_str: list[str]
    if horizons is None:
        if isinstance(returns_data, pd.DataFrame):
            horizons_str = [str(col) for col in returns_data.columns]
        else:
            horizons_str = ["1"]
    else:
        horizons_str = [str(h) for h in horizons]

    # Calculate rolling IC for each horizon
    window_size = min(60, len(predictions_data) // 4)  # Adaptive window

    # Convert Series to DataFrame for _compute_ic_matrix_optimized
    pred_for_ic: pd.DataFrame | NDArray[Any]
    if isinstance(predictions_data, pd.Series):
        pred_for_ic = predictions_data.to_frame()
    elif isinstance(predictions_data, pd.DataFrame):
        pred_for_ic = predictions_data
    else:
        pred_for_ic = predictions_data

    # Use optimized Polars implementation for all cases
    ic_matrix = _compute_ic_matrix_optimized(
        pred_for_ic,
        returns_data,
        horizons_str,
        window_size,
        use_streaming,
    )

    # Create time index
    x_values: pd.Index | pd.DatetimeIndex
    if time_index is not None:
        x_values = time_index[window_size:]
    else:
        x_values = pd.Index(list(range(window_size, len(predictions_data))))

    # Create heatmap
    fig = go.Figure(
        data=go.Heatmap(
            z=ic_matrix,
            x=x_values,
            y=[f"{h}d" for h in horizons_str],
            colorscale=colorscale,
            zmid=0,
            text=np.round(ic_matrix, 3),
            texttemplate="%{text}",
            textfont={"size": 10},
            hovertemplate="Horizon: %{y}<br>Time: %{x}<br>IC: %{z:.3f}<extra></extra>",
            colorbar={"title": "IC", "tickmode": "linear", "tick0": -1, "dtick": 0.2},
        ),
    )

    # Update layout
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis_title="Date" if time_index is not None else "Time",
        yaxis_title="Forward Return Horizon",
        **DEFAULT_LAYOUT,
    )

    # Add regime filtering if specified
    if regime_column is not None:
        # This would add a dropdown for regime filtering
        # Implementation depends on regime data structure
        pass

    return fig


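# ---------------------------------------------------------------------------
# Editor's note: a hedged sketch (not shipped with the package) of building
# the multi-horizon `returns` input that plot_ic_heatmap expects -- one column
# of *forward* returns per horizon, aligned to the prediction timestamps. The
# price series below is synthetic.
# ---------------------------------------------------------------------------
def _demo_forward_returns() -> pd.DataFrame:
    """Illustrative only: construct forward returns for horizons 1/5/10/20."""
    rng = np.random.default_rng(0)
    prices = pd.Series(
        100 * np.exp(np.cumsum(rng.normal(0, 0.01, 500))),
        index=pd.date_range("2020-01-01", periods=500, freq="B"),
    )
    horizons = [1, 5, 10, 20]
    # Forward h-day return at time t is p[t+h] / p[t] - 1; shift(-h) looks ahead,
    # so the last h rows of each column are NaN.
    return pd.DataFrame({str(h): prices.shift(-h) / prices - 1 for h in horizons})

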
def _compute_ic_matrix_optimized(
    predictions: Union[pd.DataFrame, "NDArray[Any]"],
    returns: Union[pd.DataFrame, "NDArray[Any]"],
    horizons: list[str],
    window_size: int,
    use_streaming: bool = True,
) -> list[list[float]]:
    """Compute IC matrix using optimized Polars operations with streaming for large datasets.

    Parameters
    ----------
    predictions : Union[pd.DataFrame, NDArray]
        Model predictions
    returns : Union[pd.DataFrame, NDArray]
        Returns data for different horizons
    horizons : list[str]
        List of horizon labels
    window_size : int
        Rolling window size for IC calculation
    use_streaming : bool, default True
        Whether to use streaming for large datasets (>100k samples)

    Returns
    -------
    list[list[float]]
        IC matrix with shape (n_horizons, n_time_points)
    """
    # Convert to Polars DataFrame
    data_dict: dict[str, NDArray[Any]] = {}

    # Handle predictions
    if isinstance(predictions, np.ndarray):
        pred_array = predictions.flatten()
    elif hasattr(predictions, "values"):
        pred_array = predictions.values.flatten()
    else:
        pred_array = np.array(predictions).flatten()

    data_dict["predictions"] = pred_array
    n_samples = len(pred_array)

    # Handle returns
    if isinstance(returns, pd.DataFrame):
        for i, horizon in enumerate(horizons):
            if i < returns.shape[1]:
                data_dict[f"returns_{horizon}"] = returns.iloc[:, i].to_numpy()
            else:
                data_dict[f"returns_{horizon}"] = returns.iloc[:, 0].to_numpy()
    elif isinstance(returns, np.ndarray):
        if returns.ndim == 2:
            for i, horizon in enumerate(horizons):
                if i < returns.shape[1]:
                    data_dict[f"returns_{horizon}"] = returns[:, i]
                else:
                    data_dict[f"returns_{horizon}"] = returns[:, 0]
        else:
            for horizon in horizons:
                data_dict[f"returns_{horizon}"] = returns
    else:
        # Assume single series
        ret_array = np.array(returns).flatten()
        for horizon in horizons:
            data_dict[f"returns_{horizon}"] = ret_array

    # Create Polars DataFrame
    df = pl.DataFrame(data_dict)

    # Choose appropriate method based on dataset size and streaming preference
    returns_matrix = df.select([f"returns_{h}" for h in horizons])
    min_periods = max(2, window_size // 2)

    if use_streaming and n_samples > 100000:
        # Use streaming method for large datasets
        ic_results = PolarsBackend.fast_multi_horizon_ic_streaming(
            df["predictions"],
            returns_matrix,
            window_size,
            min_periods=min_periods,
            chunk_size=PolarsBackend.adaptive_chunk_size(
                n_samples,
                len(horizons) + 1,
                target_memory_mb=500,
            ),
        )
    else:
        # Use standard method for smaller datasets
        ic_results = PolarsBackend.fast_multi_horizon_ic(
            df["predictions"],
            returns_matrix,
            window_size,
            min_periods=min_periods,
        )

    # Extract IC matrix
    ic_matrix = []
    for horizon in horizons:
        ic_series = ic_results[f"ic_returns_{horizon}"]
        # Remove initial NaN values and convert to list
        ic_values = ic_series.drop_nulls().to_list()
        # Trim to remove window startup
        if len(ic_values) > window_size:
            ic_values = ic_values[window_size:]
        ic_matrix.append(ic_values)

    return ic_matrix


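# ---------------------------------------------------------------------------
# Editor's note: the PolarsBackend internals are not shown in this diff. The
# sketch below (not part of the package) is a plain-pandas reference for what
# a rolling-window IC series means -- the correlation between predictions and
# one horizon's forward returns over each trailing window -- which the
# optimized backend is expected to match up to NaN-handling policy.
# ---------------------------------------------------------------------------
def _reference_rolling_ic(
    predictions: pd.Series, forward_returns: pd.Series, window: int
) -> pd.Series:
    """Illustrative only: trailing-window correlation as a rolling IC."""
    # Mirrors the module's min_periods policy of max(2, window // 2).
    return predictions.rolling(window, min_periods=max(2, window // 2)).corr(
        forward_returns
    )

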
def plot_quantile_returns(
    predictions: Union[pd.Series, "NDArray[Any]"],
    returns: Union[pd.Series, "NDArray[Any]"],
    n_quantiles: int = 5,
    show_cumulative: bool = True,
    title: str = "Returns by Prediction Quantile",
) -> go.Figure:
    """Create a quantile bar chart with optional cumulative returns.

    This visualization shows average returns for each prediction quantile,
    helping validate monotonic relationships between predictions and outcomes.

    Parameters
    ----------
    predictions : pd.Series or ndarray
        Model predictions
    returns : pd.Series or ndarray
        Actual returns
    n_quantiles : int, default 5
        Number of quantiles to create
    show_cumulative : bool, default True
        Whether to show a cumulative returns subplot
    title : str
        Plot title

    Returns
    -------
    go.Figure
        Interactive Plotly figure with quantile analysis
    """
    # Store original index if available
    time_index = None
    if isinstance(returns, pd.Series):
        time_index = returns.index
    elif isinstance(predictions, pd.Series):
        time_index = predictions.index

    # Convert to numpy arrays for consistent processing
    pred_arr: NDArray[Any]
    ret_arr: NDArray[Any]
    if isinstance(predictions, pd.Series):
        pred_arr = predictions.to_numpy()
    else:
        pred_arr = predictions
    if isinstance(returns, pd.Series):
        ret_arr = returns.to_numpy()
    else:
        ret_arr = returns

    # Handle edge cases
    if len(pred_arr) == 0 or len(ret_arr) == 0:
        # Return empty figure
        fig = go.Figure()
        fig.update_layout(title=title)
        return fig

    # Check for all NaN
    if np.all(np.isnan(pred_arr)) or np.all(np.isnan(ret_arr)):
        # Return empty figure with message
        fig = go.Figure()
        fig.add_annotation(
            text="No valid data to display",
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        fig.update_layout(title=title)
        return fig

    # Create quantiles
    quantile_labels: NDArray[Any]
    try:
        quantile_result = pd.qcut(pred_arr, n_quantiles, labels=False, duplicates="drop") + 1
        quantile_labels = (
            quantile_result.to_numpy()
            if hasattr(quantile_result, "to_numpy")
            else np.array(quantile_result)
        )
    except ValueError:
        # If quantiles can't be created, fall back to evenly spaced labels
        quantile_labels = np.linspace(1, n_quantiles, len(pred_arr), dtype=int)

    # Calculate mean returns per quantile
    quantile_returns = []
    quantile_counts: list[int] = []
    std_errors = []

    for q in range(1, n_quantiles + 1):
        mask = quantile_labels == q
        q_returns = ret_arr[mask]
        quantile_returns.append(np.mean(q_returns))
        quantile_counts.append(np.sum(mask))
        std_errors.append(np.std(q_returns) / np.sqrt(len(q_returns)))

    # Create figure
    if show_cumulative:
        fig = make_subplots(
            rows=2,
            cols=1,
            row_heights=[0.6, 0.4],
            shared_xaxes=True,
            vertical_spacing=0.1,
            subplot_titles=("Mean Returns by Quantile", "Cumulative Returns"),
        )
    else:
        fig = go.Figure()

    # Colors based on return sign
    colors = [COLORS["positive"] if r > 0 else COLORS["negative"] for r in quantile_returns]

    # Add bar chart
    bar_trace = go.Bar(
        x=list(range(1, n_quantiles + 1)),
        y=quantile_returns,
        error_y={"type": "data", "array": std_errors, "visible": True},
        marker_color=colors,
        text=[f"{r:.2%}" for r in quantile_returns],
        textposition="outside",
        hovertemplate=(
            "Quantile %{x}<br>Mean Return: %{y:.2%}<br>Count: %{customdata}<extra></extra>"
        ),
        customdata=quantile_counts,
        name="Mean Return",
        showlegend=False,
    )

    if show_cumulative:
        fig.add_trace(bar_trace, row=1, col=1)

        # Calculate cumulative returns for each quantile with proper time alignment
        for q in range(1, n_quantiles + 1):
            mask = quantile_labels == q

            # If we have a time index, use it for proper alignment
            if time_index is not None:
                # Get returns and their corresponding times
                q_indices = np.where(mask)[0]
                # Convert to numpy to avoid pandas index issues with positional sorting
                q_returns_arr = ret_arr[mask]
                q_times = time_index[q_indices]

                # Sort by time
                time_order = np.argsort(q_times)
                q_returns_sorted = q_returns_arr[time_order]
                q_times_sorted = q_times[time_order]

                # Calculate cumulative returns on time-sorted data
                cumulative = np.cumprod(1 + q_returns_sorted) - 1

                fig.add_trace(
                    go.Scatter(
                        x=q_times_sorted,
                        y=cumulative,
                        mode="lines",
                        name=f"Q{q}",
                        line={"width": 2},
                        hovertemplate=(
                            "Quantile %{fullData.name}<br>Time: %{x}<br>Cumulative: %{y:.2%}<extra></extra>"
                        ),
                    ),
                    row=2,
                    col=1,
                )
            else:
                # Fall back to position-based alignment if there is no time index
                q_returns_arr = ret_arr[mask]
                cumulative = np.cumprod(1 + q_returns_arr) - 1

                fig.add_trace(
                    go.Scatter(
                        x=np.arange(len(cumulative)),
                        y=cumulative,
                        mode="lines",
                        name=f"Q{q}",
                        line={"width": 2},
                        hovertemplate=(
                            "Quantile %{fullData.name}<br>Position: %{x}<br>Cumulative: %{y:.2%}<extra></extra>"
                        ),
                    ),
                    row=2,
                    col=1,
                )
    else:
        fig.add_trace(bar_trace)

    # Update layout
    fig.update_xaxes(
        title_text="Prediction Quantile",
        row=2 if show_cumulative else 1,
        col=1,
    )
    fig.update_yaxes(title_text="Mean Return", tickformat=".1%", row=1, col=1)

    if show_cumulative:
        fig.update_yaxes(title_text="Cumulative Return", tickformat=".1%", row=2, col=1)
        # Update x-axis label based on whether we have a time index
        x_label = "Time" if time_index is not None else "Position"
        fig.update_xaxes(title_text=x_label, row=2, col=1)

    fig.update_layout(title={"text": title, "x": 0.5, "xanchor": "center"}, **DEFAULT_LAYOUT)

    return fig


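# ---------------------------------------------------------------------------
# Editor's note: a usage sketch for plot_quantile_returns on synthetic data
# (not part of the package). The returns are built with a weak linear link to
# the predictions, so the quantile bars should rise from Q1 to Q5.
# ---------------------------------------------------------------------------
def _demo_quantile_returns() -> None:
    """Illustrative only: call plot_quantile_returns on correlated noise."""
    rng = np.random.default_rng(42)
    idx = pd.date_range("2021-01-01", periods=1000, freq="B")
    preds = pd.Series(rng.normal(size=1000), index=idx)
    # Signal-plus-noise returns: ~0.01 sensitivity to the prediction.
    rets = pd.Series(0.01 * preds.to_numpy() + rng.normal(0, 0.02, 1000), index=idx)
    fig = plot_quantile_returns(preds, rets, n_quantiles=5)
    fig.show()

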
def plot_turnover_decay(
    factor_values: pd.DataFrame,
    quantiles: int = 5,
    lags: list[int] | None = None,
    title: str = "Factor Turnover and Decay Analysis",
) -> go.Figure:
    """Visualize factor stability through turnover and autocorrelation analysis.

    Parameters
    ----------
    factor_values : pd.DataFrame
        Time series of factor values (index = time, columns = assets)
    quantiles : int, default 5
        Number of quantiles for turnover calculation
    lags : list[int], optional
        Autocorrelation lags to compute. Default [1, 5, 10, 20]
    title : str
        Plot title

    Returns
    -------
    go.Figure
        Multi-panel figure showing turnover and decay analysis
    """
    if lags is None:
        lags = [1, 5, 10, 20]

    # Create subplots
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Quantile Turnover by Period",
            "Average Autocorrelation Decay",
            "Turnover Heatmap",
            "Signal Stability",
        ),
        specs=[
            [{"type": "bar"}, {"type": "scatter"}],
            [{"type": "heatmap"}, {"type": "scatter"}],
        ],
    )

    # Calculate quantile assignments
    quantile_assignments = factor_values.apply(
        lambda x: pd.qcut(x, quantiles, labels=False, duplicates="drop"),
        axis=0,
    )

    # 1. Calculate turnover for each quantile
    turnover_by_quantile = []
    for q in range(quantiles):
        # Count changes in quantile assignment
        in_quantile = (quantile_assignments == q).astype(int)
        changes = in_quantile.diff().abs().sum(axis=1)
        total = in_quantile.sum(axis=1)
        turnover = (changes / (2 * total)).fillna(0).mean()
        turnover_by_quantile.append(turnover)

    # Add turnover bar chart
    fig.add_trace(
        go.Bar(
            x=list(range(1, quantiles + 1)),
            y=turnover_by_quantile,
            marker_color=COLORS["primary"],
            text=[f"{t:.1%}" for t in turnover_by_quantile],
            textposition="outside",
            hovertemplate="Quantile %{x}<br>Turnover: %{y:.1%}<extra></extra>",
            showlegend=False,
        ),
        row=1,
        col=1,
    )

    # 2. Calculate autocorrelation decay
    autocorr_values = []
    for lag in lags:
        # Calculate autocorrelation for each asset
        autocorr = factor_values.apply(
            lambda x, current_lag=lag: x.autocorr(lag=current_lag),
        )
        autocorr_values.append(autocorr.mean())

    # Add autocorrelation decay plot
    fig.add_trace(
        go.Scatter(
            x=lags,
            y=autocorr_values,
            mode="lines+markers",
            marker={"size": 10, "color": COLORS["secondary"]},
            line={"width": 3, "color": COLORS["secondary"]},
            hovertemplate="Lag %{x}<br>Autocorr: %{y:.3f}<extra></extra>",
            showlegend=False,
        ),
        row=1,
        col=2,
    )

    # 3. Turnover heatmap (time vs quantile)
    # Sample time periods for visualization
    time_periods = min(20, len(factor_values) // 10)
    sample_indices = np.linspace(0, len(factor_values) - 2, time_periods, dtype=int)

    turnover_matrix = []
    for idx in sample_indices:
        period_turnover = []
        for q in range(quantiles):
            in_q_t0 = quantile_assignments.iloc[idx] == q
            in_q_t1 = quantile_assignments.iloc[idx + 1] == q
            stayed = (in_q_t0 & in_q_t1).sum()
            total = in_q_t0.sum()
            turnover = 1 - (stayed / total) if total > 0 else 0
            period_turnover.append(turnover)
        turnover_matrix.append(period_turnover)

    fig.add_trace(
        go.Heatmap(
            z=turnover_matrix,
            x=list(range(1, quantiles + 1)),
            y=sample_indices,
            colorscale="Reds",
            hovertemplate=("Time: %{y}<br>Quantile: %{x}<br>Turnover: %{z:.1%}<extra></extra>"),
            showscale=True,
            colorbar={"title": "Turnover", "x": 1.15},
        ),
        row=2,
        col=1,
    )

    # 4. Signal stability (rolling mean of factor values)
    rolling_mean = factor_values.mean(axis=1).rolling(window=20).mean()
    rolling_std = factor_values.mean(axis=1).rolling(window=20).std()

    fig.add_trace(
        go.Scatter(
            x=factor_values.index,
            y=rolling_mean,
            mode="lines",
            line={"color": COLORS["primary"], "width": 2},
            name="Rolling Mean",
            hovertemplate="Time: %{x}<br>Mean: %{y:.3f}<extra></extra>",
        ),
        row=2,
        col=2,
    )

    # Add confidence bands
    fig.add_trace(
        go.Scatter(
            x=factor_values.index,
            y=rolling_mean + 2 * rolling_std,
            mode="lines",
            line={"width": 0},
            showlegend=False,
            hoverinfo="skip",
        ),
        row=2,
        col=2,
    )

    fig.add_trace(
        go.Scatter(
            x=factor_values.index,
            y=rolling_mean - 2 * rolling_std,
            mode="lines",
            line={"width": 0},
            fill="tonexty",
            fillcolor="rgba(51, 102, 204, 0.2)",
            name="±2 STD",
            hoverinfo="skip",
        ),
        row=2,
        col=2,
    )

    # Update axes
    fig.update_xaxes(title_text="Quantile", row=1, col=1)
    fig.update_yaxes(title_text="Turnover Rate", tickformat=".0%", row=1, col=1)

    fig.update_xaxes(title_text="Lag (days)", row=1, col=2)
    fig.update_yaxes(title_text="Autocorrelation", row=1, col=2)

    fig.update_xaxes(title_text="Quantile", row=2, col=1)
    fig.update_yaxes(title_text="Time Period", row=2, col=1)

    fig.update_xaxes(title_text="Date", row=2, col=2)
    fig.update_yaxes(title_text="Factor Value", row=2, col=2)

    # Update layout
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        height=800,
        **DEFAULT_LAYOUT,
    )

    return fig


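# ---------------------------------------------------------------------------
# Editor's note: a usage sketch for plot_turnover_decay (not part of the
# package). The factor is smoothed with a 10-day rolling mean so quantile
# membership is sticky -- the per-quantile turnover, changes / (2 * total)
# averaged over time, should then stay well below 100%.
# ---------------------------------------------------------------------------
def _demo_turnover_decay() -> None:
    """Illustrative only: call plot_turnover_decay on a synthetic factor."""
    rng = np.random.default_rng(7)
    idx = pd.date_range("2022-01-01", periods=500, freq="B")
    factor = (
        pd.DataFrame(
            rng.normal(size=(500, 50)),
            index=idx,
            columns=[f"asset_{i}" for i in range(50)],
        )
        .rolling(10, min_periods=1)
        .mean()  # Smoothing induces positive autocorrelation and low turnover.
    )
    fig = plot_turnover_decay(factor, quantiles=5, lags=[1, 5, 10, 20])
    fig.show()

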
def plot_feature_distributions(
    features: pd.DataFrame,
    n_periods: int = 4,
    method: str = "box",
    title: str = "Feature Distribution Analysis",
) -> go.Figure:
    """Create small multiples showing feature distributions over time.

    Parameters
    ----------
    features : pd.DataFrame
        Feature values (index = time, columns = features)
    n_periods : int, default 4
        Number of time periods to show
    method : str, default "box"
        Plot type: "box", "violin", or "hist"
    title : str
        Plot title

    Returns
    -------
    go.Figure
        Small multiples visualization
    """
    # Limit to the first 9 features for readability
    n_features = min(9, features.shape[1])
    feature_cols = features.columns[:n_features]

    # Create time buckets
    period_size = len(features) // n_periods
    periods = []
    period_labels = []

    for i in range(n_periods):
        start_idx = i * period_size
        end_idx = (i + 1) * period_size if i < n_periods - 1 else len(features)
        periods.append((start_idx, end_idx))

        if hasattr(features.index, "date"):
            start_date = features.index[start_idx].strftime("%Y-%m")
            end_date = features.index[end_idx - 1].strftime("%Y-%m")
            period_labels.append(f"{start_date} to {end_date}")
        else:
            period_labels.append(f"Period {i + 1}")

    # Create subplots
    n_rows = int(np.ceil(n_features / 3))
    fig = make_subplots(
        rows=n_rows,
        cols=3,
        subplot_titles=[str(col) for col in feature_cols],
        vertical_spacing=0.15,
        horizontal_spacing=0.1,
    )

    # Add plots for each feature
    for idx, feature in enumerate(feature_cols):
        row = idx // 3 + 1
        col = idx % 3 + 1

        for period_idx, (start, end) in enumerate(periods):
            period_data = features[feature].iloc[start:end]

            if method == "box":
                fig.add_trace(
                    go.Box(
                        y=period_data,
                        name=period_labels[period_idx],
                        marker_color=px.colors.qualitative.Set3[period_idx],
                        boxpoints="outliers",
                        showlegend=(idx == 0),
                        legendgroup=f"period{period_idx}",
                        hovertemplate="%{y:.3f}<extra></extra>",
                    ),
                    row=row,
                    col=col,
                )

            elif method == "violin":
                fig.add_trace(
                    go.Violin(
                        y=period_data,
                        name=period_labels[period_idx],
                        marker_color=px.colors.qualitative.Set3[period_idx],
                        box_visible=True,
                        meanline_visible=True,
                        showlegend=(idx == 0),
                        legendgroup=f"period{period_idx}",
                        hovertemplate="%{y:.3f}<extra></extra>",
                    ),
                    row=row,
                    col=col,
                )

            elif method == "hist":
                fig.add_trace(
                    go.Histogram(
                        x=period_data,
                        name=period_labels[period_idx],
                        marker_color=px.colors.qualitative.Set3[period_idx],
                        opacity=0.7,
                        showlegend=(idx == 0),
                        legendgroup=f"period{period_idx}",
                        hovertemplate="Value: %{x:.3f}<br>Count: %{y}<extra></extra>",
                        histnorm="probability",
                    ),
                    row=row,
                    col=col,
                )

    # Update layout
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        height=300 * n_rows,
        showlegend=True,
        legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
        **DEFAULT_LAYOUT,
    )

    # Update axes
    if method == "hist":
        fig.update_xaxes(title_text="Value")
        fig.update_yaxes(title_text="Probability")
    else:
        fig.update_yaxes(title_text="Value")

    return fig


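# ---------------------------------------------------------------------------
# Editor's note: a usage sketch for plot_feature_distributions (not part of
# the package). One feature is given a slow drift so the per-period
# distributions visibly shift across the four time buckets.
# ---------------------------------------------------------------------------
def _demo_feature_distributions() -> None:
    """Illustrative only: small multiples over a drifting synthetic feature."""
    rng = np.random.default_rng(3)
    idx = pd.date_range("2020-01-01", periods=800, freq="B")
    features = pd.DataFrame(
        {
            "momentum": rng.normal(size=800) + np.linspace(0, 1, 800),  # drifting
            "value": rng.normal(size=800),
            "quality": rng.normal(size=800),
        },
        index=idx,
    )
    fig = plot_feature_distributions(features, n_periods=4, method="violin")
    fig.show()

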
def plot_ic_decay(
    decay_results: dict[str, Any],
    show_half_life: bool = True,
    show_optimal: bool = True,
    title: str | None = None,
) -> go.Figure:
    """Plot IC decay curve with half-life and optimal horizon annotations.

    Creates an interactive Plotly visualization showing how IC decays across
    prediction horizons, with optional markers for half-life and optimal horizon.

    Parameters
    ----------
    decay_results : dict
        Results from compute_ic_decay()
    show_half_life : bool, default True
        Show vertical line at estimated half-life
    show_optimal : bool, default True
        Show marker at optimal horizon
    title : str | None, default None
        Custom title for the plot. If None, uses "IC Decay Analysis"

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive Plotly figure

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation.metrics import compute_ic_decay
    >>> from ml4t.diagnostic.evaluation.visualization import plot_ic_decay
    >>>
    >>> # Compute decay
    >>> decay = compute_ic_decay(pred_df, price_df, group_col="symbol")
    >>>
    >>> # Visualize
    >>> fig = plot_ic_decay(decay)
    >>> fig.show()
    """
    horizons = decay_results["horizons"]
    ic_by_horizon = decay_results["ic_by_horizon"]
    half_life = decay_results.get("half_life")
    optimal_horizon = decay_results.get("optimal_horizon")

    # Extract IC values in order
    ic_values = [ic_by_horizon[h] for h in horizons]

    # Create figure
    fig = go.Figure()

    # Add IC decay curve
    fig.add_trace(
        go.Scatter(
            x=horizons,
            y=ic_values,
            mode="lines+markers",
            name="IC",
            line={"color": COLORS["primary"], "width": 2},
            marker={"size": 8, "color": COLORS["primary"]},
            hovertemplate="Horizon: %{x} days<br>IC: %{y:.4f}<extra></extra>",
        )
    )

    # Add zero line for reference
    fig.add_hline(y=0, line={"color": COLORS["grid"], "width": 1, "dash": "dash"})

    # Add half-life marker
    if show_half_life and half_life is not None:
        # Calculate IC at half-life for the marker
        if horizons[0] in ic_by_horizon:
            initial_ic = ic_by_horizon[horizons[0]]
            half_life_ic = initial_ic * 0.5

            fig.add_vline(
                x=half_life,
                line={"color": COLORS["secondary"], "width": 2, "dash": "dash"},
                annotation_text=f"Half-life: {half_life:.1f}d",
                annotation_position="top right",
            )

            # Add marker at half-life point
            fig.add_trace(
                go.Scatter(
                    x=[half_life],
                    y=[half_life_ic],
                    mode="markers",
                    name="Half-life",
                    marker={"size": 12, "color": COLORS["secondary"], "symbol": "diamond"},
                    hovertemplate=f"Half-life: {half_life:.1f} days<br>IC: {half_life_ic:.4f}<extra></extra>",
                )
            )

    # Add optimal horizon marker
    if show_optimal and optimal_horizon is not None:
        optimal_ic = ic_by_horizon[optimal_horizon]

        fig.add_trace(
            go.Scatter(
                x=[optimal_horizon],
                y=[optimal_ic],
                mode="markers",
                name="Optimal",
                marker={
                    "size": 15,
                    "color": COLORS["positive"],
                    "symbol": "star",
                    "line": {"width": 2, "color": "white"},
                },
                hovertemplate=f"Optimal: {optimal_horizon} days<br>IC: {optimal_ic:.4f}<extra></extra>",
            )
        )

    # Update layout
    if title is None:
        title = "IC Decay Analysis"

    fig.update_layout(
        title=title,
        xaxis_title="Forecast Horizon (days)",
        yaxis_title="Information Coefficient",
        showlegend=True,
        legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
        **DEFAULT_LAYOUT,
    )

    return fig


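# ---------------------------------------------------------------------------
# Editor's note: compute_ic_decay is defined elsewhere in the package and its
# output is not shown in this diff. Judging only from the keys plot_ic_decay
# reads, the result looks like the hand-built dict below -- a hedged sketch,
# not the library's actual return value.
# ---------------------------------------------------------------------------
def _demo_ic_decay() -> None:
    """Illustrative only: drive plot_ic_decay with a hand-built result dict."""
    decay = {
        "horizons": [1, 5, 10, 20],  # required: accessed directly
        "ic_by_horizon": {1: 0.08, 5: 0.05, 10: 0.03, 20: 0.01},  # required
        "half_life": 6.5,  # optional: read via .get()
        "optimal_horizon": 1,  # optional: read via .get()
    }
    fig = plot_ic_decay(decay)
    fig.show()

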
def plot_monotonicity(
    monotonicity_results: dict[str, Any],
    title: str | None = None,
    show_correlation: bool = True,
) -> go.Figure:
    """Plot quantile analysis for monotonicity testing.

    Creates a bar chart showing mean outcomes across feature quantiles,
    with annotations for monotonicity metrics.

    Parameters
    ----------
    monotonicity_results : dict
        Results from compute_monotonicity()
    title : str | None, default None
        Custom title. If None, uses "Monotonicity Analysis"
    show_correlation : bool, default True
        Show correlation coefficient in subtitle

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive Plotly figure

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation.metrics import compute_monotonicity
    >>> from ml4t.diagnostic.evaluation.visualization import plot_monotonicity
    >>>
    >>> # Compute monotonicity
    >>> result = compute_monotonicity(features, outcomes, n_quantiles=5)
    >>>
    >>> # Visualize
    >>> fig = plot_monotonicity(result)
    >>> fig.show()
    """
    quantile_labels = monotonicity_results["quantile_labels"]
    quantile_means = monotonicity_results["quantile_means"]
    correlation = monotonicity_results["correlation"]
    p_value = monotonicity_results["p_value"]
    is_monotonic = monotonicity_results["is_monotonic"]
    monotonicity_score = monotonicity_results["monotonicity_score"]
    direction = monotonicity_results["direction"]

    # Determine bar colors based on values
    colors = [COLORS["positive"] if x > 0 else COLORS["negative"] for x in quantile_means]

    # Create figure
    fig = go.Figure()

    # Add bar chart
    fig.add_trace(
        go.Bar(
            x=quantile_labels,
            y=quantile_means,
            marker={"color": colors, "line": {"color": "white", "width": 1}},
            hovertemplate="<b>%{x}</b><br>Mean Outcome: %{y:.4f}<extra></extra>",
            name="Mean Outcome",
        )
    )

    # Add zero line
    fig.add_hline(y=0, line={"color": COLORS["grid"], "width": 1, "dash": "dash"})

    # Build title and subtitle
    if title is None:
        title = "Monotonicity Analysis"

    subtitle_parts = []
    if show_correlation:
        subtitle_parts.append(f"Correlation: {correlation:.3f} (p={p_value:.4f})")

    subtitle_parts.append(f"Monotonicity: {monotonicity_score:.1%}")
    subtitle_parts.append(f"Direction: {direction.replace('_', ' ').title()}")

    if is_monotonic:
        subtitle_parts.append("✓ Monotonic")
    else:
        subtitle_parts.append("✗ Not Monotonic")

    subtitle = " | ".join(subtitle_parts)

    # Update layout
    fig.update_layout(
        title={
            "text": f"<b>{title}</b><br><sub>{subtitle}</sub>",
            "x": 0.5,
            "xanchor": "center",
        },
        xaxis_title="Feature Quantile",
        yaxis_title="Mean Outcome",
        showlegend=False,
        **DEFAULT_LAYOUT,
    )

    return fig
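

# ---------------------------------------------------------------------------
# Editor's note: as with plot_ic_decay above, this hand-built dict mirrors the
# keys plot_monotonicity reads; compute_monotonicity's real output is not shown
# in this diff, so treat the shape as an assumption.
# ---------------------------------------------------------------------------
def _demo_monotonicity() -> None:
    """Illustrative only: drive plot_monotonicity with a hand-built dict."""
    result = {
        "quantile_labels": ["Q1", "Q2", "Q3", "Q4", "Q5"],
        "quantile_means": [-0.004, -0.001, 0.001, 0.003, 0.006],
        "correlation": 0.95,
        "p_value": 0.012,
        "is_monotonic": True,
        "monotonicity_score": 1.0,  # rendered as "100.0%" in the subtitle
        "direction": "increasing",  # rendered as "Increasing"
    }
    fig = plot_monotonicity(result)
    fig.show()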