ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1050 @@
+ """Interactive visualizations for ml4t-diagnostic evaluation results.
+
+ This module provides Plotly-based visualizations for the Three-Tier
+ Validation Framework, including IC heatmaps, quantile analysis,
+ and comprehensive evaluation dashboards.
+ """
+
+ from typing import TYPE_CHECKING, Any, Union, cast
+
+ import numpy as np
+ import pandas as pd
+ import plotly.express as px
+ import plotly.graph_objects as go
+ import polars as pl
+ from plotly.subplots import make_subplots
+
+ from ml4t.diagnostic.backends.polars_backend import PolarsBackend
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+ # Color schemes for financial data
+ COLORS = {
+     "positive": "#00CC88",  # Green for positive returns
+     "negative": "#FF4444",  # Red for negative returns
+     "neutral": "#888888",  # Gray for neutral
+     "primary": "#3366CC",  # Blue for primary data
+     "secondary": "#FF9900",  # Orange for secondary data
+     "background": "#F8F9FA",
+     "grid": "#E0E0E0",
+ }
+
+ # Plotly theme configuration
+ DEFAULT_LAYOUT = {
+     "font": {"family": "Arial, sans-serif", "size": 12},
+     "plot_bgcolor": COLORS["background"],
+     "paper_bgcolor": "white",
+     "hovermode": "closest",
+     "margin": {"l": 60, "r": 30, "t": 50, "b": 60},
+     "xaxis": {"gridcolor": COLORS["grid"], "zeroline": False},
+     "yaxis": {"gridcolor": COLORS["grid"], "zeroline": False},
+ }
+
+
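Every plotting function below follows the same theming pattern: traces pick colors from COLORS, and each figure ends with fig.update_layout(**DEFAULT_LAYOUT). A minimal sketch of that pattern (the figure content is illustrative, not from the package):

    import plotly.graph_objects as go

    fig = go.Figure(go.Bar(x=["Q1", "Q2"], y=[0.01, 0.02], marker_color=COLORS["primary"]))
    fig.update_layout(title="Demo", **DEFAULT_LAYOUT)  # shared theme applied last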
+ def plot_ic_heatmap(
+     predictions: Union[pd.DataFrame, "NDArray[Any]"],
+     returns: Union[pd.DataFrame, "NDArray[Any]"],
+     horizons: list[int] | None = None,
+     time_index: pd.DatetimeIndex | None = None,
+     regime_column: str | None = None,
+     title: str = "Information Coefficient Term Structure",
+     colorscale: str = "RdBu",
+     _use_optimized: bool = True,
+     use_streaming: bool = True,
+ ) -> go.Figure:
+     """Create an interactive IC heatmap across multiple forward return horizons.
+
+     This visualization shows how predictive power varies across different
+     prediction horizons, helping identify the optimal holding period.
+
+     Parameters
+     ----------
+     predictions : pd.DataFrame or ndarray
+         Model predictions (same for all horizons)
+     returns : pd.DataFrame or ndarray
+         Forward returns for different horizons (columns = horizons)
+     horizons : list[int], optional
+         List of forward return horizons. If None, uses column names
+     time_index : pd.DatetimeIndex, optional
+         Time index for x-axis. If None, uses integer index
+     regime_column : str, optional
+         Column name for market regime filtering
+     title : str, default "Information Coefficient Term Structure"
+         Plot title
+     colorscale : str, default "RdBu"
+         Plotly colorscale name
+     _use_optimized : bool, default True
+         Unused; retained for backward compatibility. The optimized Polars
+         backend is always used for performance
+     use_streaming : bool, default True
+         Whether to stream the IC computation for large datasets
+
+     Returns
+     -------
+     go.Figure
+         Interactive Plotly figure
+
+     Examples
+     --------
+     >>> # Simple usage
+     >>> fig = plot_ic_heatmap(predictions, forward_returns)
+     >>> fig.show()
+
+     >>> # With custom horizons
+     >>> fig = plot_ic_heatmap(
+     ...     predictions,
+     ...     returns_df,
+     ...     horizons=[1, 5, 10, 20],
+     ...     time_index=returns_df.index,
+     ... )
+     """
+     # Convert inputs to appropriate types
+     predictions_data: pd.Series | pd.DataFrame | NDArray[Any]
+     if isinstance(predictions, np.ndarray):
+         predictions_data = pd.Series(predictions, name="predictions")
+     else:
+         predictions_data = predictions
+
+     returns_data: pd.DataFrame | NDArray[Any]
+     if isinstance(returns, np.ndarray):
+         returns_data = (
+             pd.DataFrame(returns, columns=cast(Any, ["returns"]))
+             if returns.ndim == 1
+             else pd.DataFrame(returns)
+         )
+     else:
+         returns_data = returns
+
+     # Determine horizons as strings for internal processing
+     horizons_str: list[str]
+     if horizons is None:
+         if isinstance(returns_data, pd.DataFrame):
+             horizons_str = [str(col) for col in returns_data.columns]
+         else:
+             horizons_str = ["1"]
+     else:
+         horizons_str = [str(h) for h in horizons]
+
+     # Calculate rolling IC for each horizon
+     window_size = min(60, len(predictions_data) // 4)  # Adaptive window
+
+     # Convert a Series to a DataFrame for _compute_ic_matrix_optimized
+     pred_for_ic: pd.DataFrame | NDArray[Any]
+     if isinstance(predictions_data, pd.Series):
+         pred_for_ic = predictions_data.to_frame()
+     else:
+         pred_for_ic = predictions_data
+
+     # Use the optimized Polars implementation for all cases
+     ic_matrix = _compute_ic_matrix_optimized(
+         pred_for_ic,
+         returns_data,
+         horizons_str,
+         window_size,
+         use_streaming,
+     )
+
+     # Create time index
+     x_values: pd.Index | pd.DatetimeIndex
+     if time_index is not None:
+         x_values = time_index[window_size:]
+     else:
+         x_values = pd.Index(list(range(window_size, len(predictions_data))))
+
+     # Create heatmap
+     fig = go.Figure(
+         data=go.Heatmap(
+             z=ic_matrix,
+             x=x_values,
+             y=[f"{h}d" for h in horizons_str],
+             colorscale=colorscale,
+             zmid=0,
+             text=np.round(ic_matrix, 3),
+             texttemplate="%{text}",
+             textfont={"size": 10},
+             hovertemplate="Horizon: %{y}<br>Time: %{x}<br>IC: %{z:.3f}<extra></extra>",
+             colorbar={"title": "IC", "tickmode": "linear", "tick0": -1, "dtick": 0.2},
+         ),
+     )
+
+     # Update layout
+     fig.update_layout(
+         title={"text": title, "x": 0.5, "xanchor": "center"},
+         xaxis_title="Date" if time_index is not None else "Time",
+         yaxis_title="Forward Return Horizon",
+         **DEFAULT_LAYOUT,
+     )
+
+     # Add regime filtering if specified
+     if regime_column is not None:
+         # This would add a dropdown for regime filtering;
+         # implementation depends on the regime data structure
+         pass
+
+     return fig
+
+
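A usage sketch for plot_ic_heatmap with synthetic data; names and values below are illustrative, not from the package:

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    idx = pd.date_range("2022-01-03", periods=500, freq="B")
    preds = pd.Series(rng.normal(size=len(idx)), index=idx)
    # One column of forward returns per horizon
    fwd = pd.DataFrame(
        {h: rng.normal(scale=0.01, size=len(idx)) for h in [1, 5, 10, 20]},
        index=idx,
    )
    fig = plot_ic_heatmap(preds, fwd, horizons=[1, 5, 10, 20], time_index=idx)
    fig.show()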
+ def _compute_ic_matrix_optimized(
+     predictions: Union[pd.DataFrame, "NDArray[Any]"],
+     returns: Union[pd.DataFrame, "NDArray[Any]"],
+     horizons: list[str],
+     window_size: int,
+     use_streaming: bool = True,
+ ) -> list[list[float]]:
+     """Compute IC matrix using optimized Polars operations with streaming for large datasets.
+
+     Parameters
+     ----------
+     predictions : Union[pd.DataFrame, NDArray]
+         Model predictions
+     returns : Union[pd.DataFrame, NDArray]
+         Returns data for different horizons
+     horizons : list[str]
+         List of horizon labels
+     window_size : int
+         Rolling window size for IC calculation
+     use_streaming : bool, default True
+         Whether to use streaming for large datasets (>100k samples)
+
+     Returns
+     -------
+     list[list[float]]
+         IC matrix with shape (n_horizons, n_time_points)
+     """
+     # Convert to Polars DataFrame
+     data_dict: dict[str, NDArray[Any]] = {}
+
+     # Handle predictions
+     if isinstance(predictions, np.ndarray):
+         pred_array = predictions.flatten()
+     elif hasattr(predictions, "values"):
+         pred_array = predictions.values.flatten()
+     else:
+         pred_array = np.array(predictions).flatten()
+
+     data_dict["predictions"] = pred_array
+     n_samples = len(pred_array)
+
+     # Handle returns
+     if isinstance(returns, pd.DataFrame):
+         for i, horizon in enumerate(horizons):
+             if i < returns.shape[1]:
+                 data_dict[f"returns_{horizon}"] = returns.iloc[:, i].to_numpy()
+             else:
+                 data_dict[f"returns_{horizon}"] = returns.iloc[:, 0].to_numpy()
+     elif isinstance(returns, np.ndarray):
+         if returns.ndim == 2:
+             for i, horizon in enumerate(horizons):
+                 if i < returns.shape[1]:
+                     data_dict[f"returns_{horizon}"] = returns[:, i]
+                 else:
+                     data_dict[f"returns_{horizon}"] = returns[:, 0]
+         else:
+             for horizon in horizons:
+                 data_dict[f"returns_{horizon}"] = returns
+     else:
+         # Assume single series
+         ret_array = np.array(returns).flatten()
+         for horizon in horizons:
+             data_dict[f"returns_{horizon}"] = ret_array
+
+     # Create Polars DataFrame
+     df = pl.DataFrame(data_dict)
+
+     # Choose appropriate method based on dataset size and streaming preference
+     returns_matrix = df.select([f"returns_{h}" for h in horizons])
+     min_periods = max(2, window_size // 2)
+
+     if use_streaming and n_samples > 100000:
+         # Use streaming method for large datasets
+         ic_results = PolarsBackend.fast_multi_horizon_ic_streaming(
+             df["predictions"],
+             returns_matrix,
+             window_size,
+             min_periods=min_periods,
+             chunk_size=PolarsBackend.adaptive_chunk_size(
+                 n_samples,
+                 len(horizons) + 1,
+                 target_memory_mb=500,
+             ),
+         )
+     else:
+         # Use standard method for smaller datasets
+         ic_results = PolarsBackend.fast_multi_horizon_ic(
+             df["predictions"],
+             returns_matrix,
+             window_size,
+             min_periods=min_periods,
+         )
+
+     # Extract IC matrix
+     ic_matrix = []
+     for horizon in horizons:
+         ic_series = ic_results[f"ic_returns_{horizon}"]
+         # Remove initial NaN values and convert to list
+         ic_values = ic_series.drop_nulls().to_list()
+         # Trim to remove window startup
+         if len(ic_values) > window_size:
+             ic_values = ic_values[window_size:]
+         ic_matrix.append(ic_values)
+
+     return ic_matrix
+
+
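The heavy lifting is delegated to PolarsBackend; semantically, each row of the IC matrix is a rolling correlation between the prediction series and one horizon's forward returns. A pandas reference sketch of that semantics (whether the backend ranks the inputs first, i.e. Spearman rather than Pearson, is decided in polars_backend.py, not here):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(1)
    preds = pd.Series(rng.normal(size=1_000))
    fwd_5d = pd.Series(rng.normal(scale=0.01, size=1_000))
    window = 60
    # One horizon's rolling IC; the backend vectorizes this across horizons
    # and switches to chunked streaming above 100k samples.
    rolling_ic = preds.rolling(window, min_periods=max(2, window // 2)).corr(fwd_5d)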
+ def plot_quantile_returns(
+     predictions: Union[pd.Series, "NDArray[Any]"],
+     returns: Union[pd.Series, "NDArray[Any]"],
+     n_quantiles: int = 5,
+     show_cumulative: bool = True,
+     title: str = "Returns by Prediction Quantile",
+ ) -> go.Figure:
+     """Create quantile bar chart with optional cumulative returns.
+
+     This visualization shows average returns for each prediction quantile,
+     helping validate monotonic relationships between predictions and outcomes.
+
+     Parameters
+     ----------
+     predictions : pd.Series or ndarray
+         Model predictions
+     returns : pd.Series or ndarray
+         Actual returns
+     n_quantiles : int, default 5
+         Number of quantiles to create
+     show_cumulative : bool, default True
+         Whether to show cumulative returns subplot
+     title : str
+         Plot title
+
+     Returns
+     -------
+     go.Figure
+         Interactive Plotly figure with quantile analysis
+     """
+     # Store original index if available
+     time_index = None
+     if isinstance(returns, pd.Series):
+         time_index = returns.index
+     elif isinstance(predictions, pd.Series):
+         time_index = predictions.index
+
+     # Convert to numpy arrays for consistent processing
+     pred_arr: NDArray[Any]
+     ret_arr: NDArray[Any]
+     if isinstance(predictions, pd.Series):
+         pred_arr = predictions.to_numpy()
+     else:
+         pred_arr = predictions
+     if isinstance(returns, pd.Series):
+         ret_arr = returns.to_numpy()
+     else:
+         ret_arr = returns
+
+     # Handle edge case: no data at all
+     if len(pred_arr) == 0 or len(ret_arr) == 0:
+         # Return empty figure
+         fig = go.Figure()
+         fig.update_layout(title=title)
+         return fig
+
+     # Check for all NaN
+     if np.all(np.isnan(pred_arr)) or np.all(np.isnan(ret_arr)):
+         # Return empty figure with message
+         fig = go.Figure()
+         fig.add_annotation(
+             text="No valid data to display",
+             xref="paper",
+             yref="paper",
+             x=0.5,
+             y=0.5,
+             showarrow=False,
+         )
+         fig.update_layout(title=title)
+         return fig
+
+     # Create quantiles
+     quantile_labels: NDArray[Any]
+     try:
+         quantile_result = pd.qcut(pred_arr, n_quantiles, labels=False, duplicates="drop") + 1
+         quantile_labels = (
+             quantile_result.to_numpy()
+             if hasattr(quantile_result, "to_numpy")
+             else np.array(quantile_result)
+         )
+     except ValueError:
+         # If quantiles cannot be created, fall back to equal positional splits
+         quantile_labels = np.linspace(1, n_quantiles, len(pred_arr), dtype=int)
+
+     # Calculate mean returns per quantile
+     quantile_returns = []
+     quantile_counts: list[int] = []
+     std_errors = []
+
+     for q in range(1, n_quantiles + 1):
+         mask = quantile_labels == q
+         q_returns = ret_arr[mask]
+         if len(q_returns) == 0:
+             # Guard against empty quantiles (possible when qcut drops duplicate bin edges)
+             quantile_returns.append(np.nan)
+             quantile_counts.append(0)
+             std_errors.append(np.nan)
+             continue
+         quantile_returns.append(np.mean(q_returns))
+         quantile_counts.append(int(np.sum(mask)))
+         std_errors.append(np.std(q_returns) / np.sqrt(len(q_returns)))
+
+     # Create figure
+     if show_cumulative:
+         fig = make_subplots(
+             rows=2,
+             cols=1,
+             row_heights=[0.6, 0.4],
+             shared_xaxes=True,
+             vertical_spacing=0.1,
+             subplot_titles=("Mean Returns by Quantile", "Cumulative Returns"),
+         )
+     else:
+         fig = go.Figure()
+
+     # Colors based on return sign
+     colors = [COLORS["positive"] if r > 0 else COLORS["negative"] for r in quantile_returns]
+
+     # Add bar chart
+     bar_trace = go.Bar(
+         x=list(range(1, n_quantiles + 1)),
+         y=quantile_returns,
+         error_y={"type": "data", "array": std_errors, "visible": True},
+         marker_color=colors,
+         text=[f"{r:.2%}" for r in quantile_returns],
+         textposition="outside",
+         hovertemplate=(
+             "Quantile %{x}<br>Mean Return: %{y:.2%}<br>Count: %{customdata}<extra></extra>"
+         ),
+         customdata=quantile_counts,
+         name="Mean Return",
+         showlegend=False,
+     )
+
+     if show_cumulative:
+         fig.add_trace(bar_trace, row=1, col=1)
+
+         # Calculate cumulative returns for each quantile with proper time alignment
+         for q in range(1, n_quantiles + 1):
+             mask = quantile_labels == q
+
+             # If we have a time index, use it for proper alignment
+             if time_index is not None:
+                 # Get returns and their corresponding times
+                 q_indices = np.where(mask)[0]
+                 # Convert to numpy to avoid pandas index issues with positional sorting
+                 q_returns_arr = ret_arr[mask]
+                 q_times = time_index[q_indices]
+
+                 # Sort by time
+                 time_order = np.argsort(q_times)
+                 q_returns_sorted = q_returns_arr[time_order]
+                 q_times_sorted = q_times[time_order]
+
+                 # Calculate cumulative returns on time-sorted data
+                 cumulative = np.cumprod(1 + q_returns_sorted) - 1
+
+                 fig.add_trace(
+                     go.Scatter(
+                         x=q_times_sorted,
+                         y=cumulative,
+                         mode="lines",
+                         name=f"Q{q}",
+                         line={"width": 2},
+                         hovertemplate=(
+                             "Quantile %{fullData.name}<br>Time: %{x}<br>"
+                             "Cumulative: %{y:.2%}<extra></extra>"
+                         ),
+                     ),
+                     row=2,
+                     col=1,
+                 )
+             else:
+                 # Fall back to position-based x-axis if no time index
+                 q_returns_arr = ret_arr[mask]
+                 cumulative = np.cumprod(1 + q_returns_arr) - 1
+
+                 fig.add_trace(
+                     go.Scatter(
+                         x=np.arange(len(cumulative)),
+                         y=cumulative,
+                         mode="lines",
+                         name=f"Q{q}",
+                         line={"width": 2},
+                         hovertemplate=(
+                             "Quantile %{fullData.name}<br>Position: %{x}<br>"
+                             "Cumulative: %{y:.2%}<extra></extra>"
+                         ),
+                     ),
+                     row=2,
+                     col=1,
+                 )
+     else:
+         fig.add_trace(bar_trace)
+
+     # Update axes; row/col references are only valid on subplot figures,
+     # and the quantile bars always sit in row 1
+     if show_cumulative:
+         fig.update_xaxes(title_text="Prediction Quantile", row=1, col=1)
+         fig.update_yaxes(title_text="Mean Return", tickformat=".1%", row=1, col=1)
+         fig.update_yaxes(title_text="Cumulative Return", tickformat=".1%", row=2, col=1)
+         # Label the x-axis according to whether we have a time index
+         x_label = "Time" if time_index is not None else "Position"
+         fig.update_xaxes(title_text=x_label, row=2, col=1)
+     else:
+         fig.update_xaxes(title_text="Prediction Quantile")
+         fig.update_yaxes(title_text="Mean Return", tickformat=".1%")
+
+     fig.update_layout(title={"text": title, "x": 0.5, "xanchor": "center"}, **DEFAULT_LAYOUT)
+
+     return fig
+
+
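A usage sketch for plot_quantile_returns: synthetic predictions with a small amount of genuine signal, so the quantile bars separate (all names and data illustrative):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(2)
    idx = pd.date_range("2020-01-01", periods=750, freq="B")
    preds = pd.Series(rng.normal(size=len(idx)), index=idx)
    rets = 0.002 * preds + pd.Series(rng.normal(scale=0.01, size=len(idx)), index=idx)
    fig = plot_quantile_returns(preds, rets, n_quantiles=5)
    fig.show()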
+ def plot_turnover_decay(
+     factor_values: pd.DataFrame,
+     quantiles: int = 5,
+     lags: list[int] | None = None,
+     title: str = "Factor Turnover and Decay Analysis",
+ ) -> go.Figure:
+     """Visualize factor stability through turnover and autocorrelation analysis.
+
+     Parameters
+     ----------
+     factor_values : pd.DataFrame
+         Time series of factor values (index = time, columns = assets)
+     quantiles : int, default 5
+         Number of quantiles for turnover calculation
+     lags : list[int], optional
+         Autocorrelation lags to compute. Default [1, 5, 10, 20]
+     title : str
+         Plot title
+
+     Returns
+     -------
+     go.Figure
+         Multi-panel figure showing turnover and decay analysis
+     """
+     if lags is None:
+         lags = [1, 5, 10, 20]
+
+     # Create subplots
+     fig = make_subplots(
+         rows=2,
+         cols=2,
+         subplot_titles=(
+             "Quantile Turnover by Period",
+             "Average Autocorrelation Decay",
+             "Turnover Heatmap",
+             "Signal Stability",
+         ),
+         specs=[
+             [{"type": "bar"}, {"type": "scatter"}],
+             [{"type": "heatmap"}, {"type": "scatter"}],
+         ],
+     )
+
+     # Calculate quantile assignments (per column)
+     quantile_assignments = factor_values.apply(
+         lambda x: pd.qcut(x, quantiles, labels=False, duplicates="drop"),
+         axis=0,
+     )
+
+     # 1. Calculate turnover for each quantile
+     turnover_by_quantile = []
+     for q in range(quantiles):
+         # Count changes in quantile assignment
+         in_quantile = (quantile_assignments == q).astype(int)
+         changes = in_quantile.diff().abs().sum(axis=1)
+         total = in_quantile.sum(axis=1)
+         turnover = (changes / (2 * total)).fillna(0).mean()
+         turnover_by_quantile.append(turnover)
+
+     # Add turnover bar chart
+     fig.add_trace(
+         go.Bar(
+             x=list(range(1, quantiles + 1)),
+             y=turnover_by_quantile,
+             marker_color=COLORS["primary"],
+             text=[f"{t:.1%}" for t in turnover_by_quantile],
+             textposition="outside",
+             hovertemplate="Quantile %{x}<br>Turnover: %{y:.1%}<extra></extra>",
+             showlegend=False,
+         ),
+         row=1,
+         col=1,
+     )
+
+     # 2. Calculate autocorrelation decay
+     autocorr_values = []
+     for lag in lags:
+         # Calculate autocorrelation for each asset
+         autocorr = factor_values.apply(
+             lambda x, current_lag=lag: x.autocorr(lag=current_lag),
+         )
+         autocorr_values.append(autocorr.mean())
+
+     # Add autocorrelation decay plot
+     fig.add_trace(
+         go.Scatter(
+             x=lags,
+             y=autocorr_values,
+             mode="lines+markers",
+             marker={"size": 10, "color": COLORS["secondary"]},
+             line={"width": 3, "color": COLORS["secondary"]},
+             hovertemplate="Lag %{x}<br>Autocorr: %{y:.3f}<extra></extra>",
+             showlegend=False,
+         ),
+         row=1,
+         col=2,
+     )
+
+     # 3. Turnover heatmap (time vs quantile)
+     # Sample time periods for visualization
+     time_periods = min(20, len(factor_values) // 10)
+     sample_indices = np.linspace(0, len(factor_values) - 2, time_periods, dtype=int)
+
+     turnover_matrix = []
+     for idx in sample_indices:
+         period_turnover = []
+         for q in range(quantiles):
+             in_q_t0 = quantile_assignments.iloc[idx] == q
+             in_q_t1 = quantile_assignments.iloc[idx + 1] == q
+             stayed = (in_q_t0 & in_q_t1).sum()
+             total = in_q_t0.sum()
+             turnover = 1 - (stayed / total) if total > 0 else 0
+             period_turnover.append(turnover)
+         turnover_matrix.append(period_turnover)
+
+     fig.add_trace(
+         go.Heatmap(
+             z=turnover_matrix,
+             x=list(range(1, quantiles + 1)),
+             y=sample_indices,
+             colorscale="Reds",
+             hovertemplate=("Time: %{y}<br>Quantile: %{x}<br>Turnover: %{z:.1%}<extra></extra>"),
+             showscale=True,
+             colorbar={"title": "Turnover", "x": 1.15},
+         ),
+         row=2,
+         col=1,
+     )
+
+     # 4. Signal stability (rolling mean of factor values)
+     rolling_mean = factor_values.mean(axis=1).rolling(window=20).mean()
+     rolling_std = factor_values.mean(axis=1).rolling(window=20).std()
+
+     fig.add_trace(
+         go.Scatter(
+             x=factor_values.index,
+             y=rolling_mean,
+             mode="lines",
+             line={"color": COLORS["primary"], "width": 2},
+             name="Rolling Mean",
+             hovertemplate="Time: %{x}<br>Mean: %{y:.3f}<extra></extra>",
+         ),
+         row=2,
+         col=2,
+     )
+
+     # Add confidence bands
+     fig.add_trace(
+         go.Scatter(
+             x=factor_values.index,
+             y=rolling_mean + 2 * rolling_std,
+             mode="lines",
+             line={"width": 0},
+             showlegend=False,
+             hoverinfo="skip",
+         ),
+         row=2,
+         col=2,
+     )
+
+     fig.add_trace(
+         go.Scatter(
+             x=factor_values.index,
+             y=rolling_mean - 2 * rolling_std,
+             mode="lines",
+             line={"width": 0},
+             fill="tonexty",
+             fillcolor="rgba(51, 102, 204, 0.2)",
+             name="±2 STD",
+             hoverinfo="skip",
+         ),
+         row=2,
+         col=2,
+     )
+
+     # Update axes
+     fig.update_xaxes(title_text="Quantile", row=1, col=1)
+     fig.update_yaxes(title_text="Turnover Rate", tickformat=".0%", row=1, col=1)
+
+     fig.update_xaxes(title_text="Lag (days)", row=1, col=2)
+     fig.update_yaxes(title_text="Autocorrelation", row=1, col=2)
+
+     fig.update_xaxes(title_text="Quantile", row=2, col=1)
+     fig.update_yaxes(title_text="Time Period", row=2, col=1)
+
+     fig.update_xaxes(title_text="Date", row=2, col=2)
+     fig.update_yaxes(title_text="Factor Value", row=2, col=2)
+
+     # Update layout
+     fig.update_layout(
+         title={"text": title, "x": 0.5, "xanchor": "center"},
+         height=800,
+         **DEFAULT_LAYOUT,
+     )
+
+     return fig
+
+
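For the turnover heatmap panel, each cell is the fraction of names that left quantile q between two sampled periods. A standalone restatement of that per-period computation (a hypothetical helper mirroring the inner loop above, not part of the module):

    import pandas as pd

    def one_period_turnover(q_t0: pd.Series, q_t1: pd.Series, q: int) -> float:
        # Fraction of assets that left quantile q between t0 and t1
        in0, in1 = q_t0 == q, q_t1 == q
        total = int(in0.sum())
        return (1.0 - (in0 & in1).sum() / total) if total > 0 else 0.0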
+ def plot_feature_distributions(
+     features: pd.DataFrame,
+     n_periods: int = 4,
+     method: str = "box",
+     title: str = "Feature Distribution Analysis",
+ ) -> go.Figure:
+     """Create small multiples showing feature distributions over time.
+
+     Parameters
+     ----------
+     features : pd.DataFrame
+         Feature values (index = time, columns = features)
+     n_periods : int, default 4
+         Number of time periods to show
+     method : str, default "box"
+         Plot type: "box", "violin", or "hist"
+     title : str
+         Plot title
+
+     Returns
+     -------
+     go.Figure
+         Small multiples visualization
+     """
+     # Limit to the first 9 features for readability
+     n_features = min(9, features.shape[1])
+     feature_cols = features.columns[:n_features]
+
+     # Create time buckets
+     period_size = len(features) // n_periods
+     periods = []
+     period_labels = []
+
+     for i in range(n_periods):
+         start_idx = i * period_size
+         end_idx = (i + 1) * period_size if i < n_periods - 1 else len(features)
+         periods.append((start_idx, end_idx))
+
+         if hasattr(features.index, "date"):
+             start_date = features.index[start_idx].strftime("%Y-%m")
+             end_date = features.index[end_idx - 1].strftime("%Y-%m")
+             period_labels.append(f"{start_date} to {end_date}")
+         else:
+             period_labels.append(f"Period {i + 1}")
+
+     # Create subplots
+     n_rows = int(np.ceil(n_features / 3))
+     fig = make_subplots(
+         rows=n_rows,
+         cols=3,
+         subplot_titles=[str(col) for col in feature_cols],
+         vertical_spacing=0.15,
+         horizontal_spacing=0.1,
+     )
+
+     # Add plots for each feature
+     for idx, feature in enumerate(feature_cols):
+         row = idx // 3 + 1
+         col = idx % 3 + 1
+
+         for period_idx, (start, end) in enumerate(periods):
+             period_data = features[feature].iloc[start:end]
+
+             if method == "box":
+                 fig.add_trace(
+                     go.Box(
+                         y=period_data,
+                         name=period_labels[period_idx],
+                         marker_color=px.colors.qualitative.Set3[period_idx],
+                         boxpoints="outliers",
+                         showlegend=(idx == 0),
+                         legendgroup=f"period{period_idx}",
+                         hovertemplate="%{y:.3f}<extra></extra>",
+                     ),
+                     row=row,
+                     col=col,
+                 )
+             elif method == "violin":
+                 fig.add_trace(
+                     go.Violin(
+                         y=period_data,
+                         name=period_labels[period_idx],
+                         marker_color=px.colors.qualitative.Set3[period_idx],
+                         box_visible=True,
+                         meanline_visible=True,
+                         showlegend=(idx == 0),
+                         legendgroup=f"period{period_idx}",
+                         hovertemplate="%{y:.3f}<extra></extra>",
+                     ),
+                     row=row,
+                     col=col,
+                 )
+             elif method == "hist":
+                 fig.add_trace(
+                     go.Histogram(
+                         x=period_data,
+                         name=period_labels[period_idx],
+                         marker_color=px.colors.qualitative.Set3[period_idx],
+                         opacity=0.7,
+                         showlegend=(idx == 0),
+                         legendgroup=f"period{period_idx}",
+                         hovertemplate="Value: %{x:.3f}<br>Count: %{y}<extra></extra>",
+                         histnorm="probability",
+                     ),
+                     row=row,
+                     col=col,
+                 )
+
+     # Update layout
+     fig.update_layout(
+         title={"text": title, "x": 0.5, "xanchor": "center"},
+         height=300 * n_rows,
+         showlegend=True,
+         legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
+         **DEFAULT_LAYOUT,
+     )
+
+     # Update axes
+     if method == "hist":
+         fig.update_xaxes(title_text="Value")
+         fig.update_yaxes(title_text="Probability")
+     else:
+         fig.update_yaxes(title_text="Value")
+
+     return fig
+
+
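A usage sketch for plot_feature_distributions on a synthetic feature panel (illustrative names):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(3)
    idx = pd.date_range("2021-01-01", periods=400, freq="B")
    feats = pd.DataFrame(
        rng.normal(size=(len(idx), 6)),
        index=idx,
        columns=[f"feature_{i}" for i in range(6)],
    )
    fig = plot_feature_distributions(feats, n_periods=4, method="violin")
    fig.show()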
+ def plot_ic_decay(
+     decay_results: dict[str, Any],
+     show_half_life: bool = True,
+     show_optimal: bool = True,
+     title: str | None = None,
+ ) -> go.Figure:
+     """Plot IC decay curve with half-life and optimal horizon annotations.
+
+     Creates an interactive Plotly visualization showing how IC decays across
+     prediction horizons, with optional markers for half-life and optimal horizon.
+
+     Parameters
+     ----------
+     decay_results : dict
+         Results from compute_ic_decay()
+     show_half_life : bool, default True
+         Show vertical line at estimated half-life
+     show_optimal : bool, default True
+         Show marker at optimal horizon
+     title : str | None, default None
+         Custom title for the plot. If None, uses "IC Decay Analysis"
+
+     Returns
+     -------
+     plotly.graph_objects.Figure
+         Interactive Plotly figure
+
+     Examples
+     --------
+     >>> from ml4t.diagnostic.evaluation.metrics import compute_ic_decay
+     >>> from ml4t.diagnostic.evaluation.visualization import plot_ic_decay
+     >>>
+     >>> # Compute decay
+     >>> decay = compute_ic_decay(pred_df, price_df, group_col="symbol")
+     >>>
+     >>> # Visualize
+     >>> fig = plot_ic_decay(decay)
+     >>> fig.show()
+     """
+     horizons = decay_results["horizons"]
+     ic_by_horizon = decay_results["ic_by_horizon"]
+     half_life = decay_results.get("half_life")
+     optimal_horizon = decay_results.get("optimal_horizon")
+
+     # Extract IC values in order
+     ic_values = [ic_by_horizon[h] for h in horizons]
+
+     # Create figure
+     fig = go.Figure()
+
+     # Add IC decay curve
+     fig.add_trace(
+         go.Scatter(
+             x=horizons,
+             y=ic_values,
+             mode="lines+markers",
+             name="IC",
+             line={"color": COLORS["primary"], "width": 2},
+             marker={"size": 8, "color": COLORS["primary"]},
+             hovertemplate="Horizon: %{x} days<br>IC: %{y:.4f}<extra></extra>",
+         )
+     )
+
+     # Add zero line for reference
+     fig.add_hline(y=0, line={"color": COLORS["grid"], "width": 1, "dash": "dash"})
+
+     # Add half-life marker
+     if show_half_life and half_life is not None:
+         # Calculate IC at half-life for the marker
+         if horizons[0] in ic_by_horizon:
+             initial_ic = ic_by_horizon[horizons[0]]
+             half_life_ic = initial_ic * 0.5
+
+             fig.add_vline(
+                 x=half_life,
+                 line={"color": COLORS["secondary"], "width": 2, "dash": "dash"},
+                 annotation_text=f"Half-life: {half_life:.1f}d",
+                 annotation_position="top right",
+             )
+
+             # Add marker at half-life point
+             fig.add_trace(
+                 go.Scatter(
+                     x=[half_life],
+                     y=[half_life_ic],
+                     mode="markers",
+                     name="Half-life",
+                     marker={"size": 12, "color": COLORS["secondary"], "symbol": "diamond"},
+                     hovertemplate=(
+                         f"Half-life: {half_life:.1f} days<br>IC: {half_life_ic:.4f}<extra></extra>"
+                     ),
+                 )
+             )
+
+     # Add optimal horizon marker
+     if show_optimal and optimal_horizon is not None:
+         optimal_ic = ic_by_horizon[optimal_horizon]
+
+         fig.add_trace(
+             go.Scatter(
+                 x=[optimal_horizon],
+                 y=[optimal_ic],
+                 mode="markers",
+                 name="Optimal",
+                 marker={
+                     "size": 15,
+                     "color": COLORS["positive"],
+                     "symbol": "star",
+                     "line": {"width": 2, "color": "white"},
+                 },
+                 hovertemplate=f"Optimal: {optimal_horizon} days<br>IC: {optimal_ic:.4f}<extra></extra>",
+             )
+         )
+
+     # Update layout
+     if title is None:
+         title = "IC Decay Analysis"
+
+     fig.update_layout(
+         title=title,
+         xaxis_title="Forecast Horizon (days)",
+         yaxis_title="Information Coefficient",
+         showlegend=True,
+         legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
+         **DEFAULT_LAYOUT,
+     )
+
+     return fig
+
+
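plot_ic_decay reads a fixed set of keys from the compute_ic_decay() output. Inferred from the accesses above (values illustrative; compute_ic_decay in evaluation/metrics is the authoritative producer):

    decay = {
        "horizons": [1, 5, 10, 20],
        "ic_by_horizon": {1: 0.062, 5: 0.041, 10: 0.028, 20: 0.015},
        "half_life": 7.5,        # optional; None suppresses the half-life marker
        "optimal_horizon": 5,    # optional; None suppresses the star marker
    }
    fig = plot_ic_decay(decay)
    fig.show()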
+ def plot_monotonicity(
+     monotonicity_results: dict[str, Any],
+     title: str | None = None,
+     show_correlation: bool = True,
+ ) -> go.Figure:
+     """Plot quantile analysis for monotonicity testing.
+
+     Creates a bar chart showing mean outcomes across feature quantiles,
+     with annotations for monotonicity metrics.
+
+     Parameters
+     ----------
+     monotonicity_results : dict
+         Results from compute_monotonicity()
+     title : str | None, default None
+         Custom title. If None, uses "Monotonicity Analysis"
+     show_correlation : bool, default True
+         Show correlation coefficient in subtitle
+
+     Returns
+     -------
+     plotly.graph_objects.Figure
+         Interactive Plotly figure
+
+     Examples
+     --------
+     >>> from ml4t.diagnostic.evaluation.metrics import compute_monotonicity
+     >>> from ml4t.diagnostic.evaluation.visualization import plot_monotonicity
+     >>>
+     >>> # Compute monotonicity
+     >>> result = compute_monotonicity(features, outcomes, n_quantiles=5)
+     >>>
+     >>> # Visualize
+     >>> fig = plot_monotonicity(result)
+     >>> fig.show()
+     """
+     quantile_labels = monotonicity_results["quantile_labels"]
+     quantile_means = monotonicity_results["quantile_means"]
+     correlation = monotonicity_results["correlation"]
+     p_value = monotonicity_results["p_value"]
+     is_monotonic = monotonicity_results["is_monotonic"]
+     monotonicity_score = monotonicity_results["monotonicity_score"]
+     direction = monotonicity_results["direction"]
+
+     # Determine bar colors based on values
+     colors = [COLORS["positive"] if x > 0 else COLORS["negative"] for x in quantile_means]
+
+     # Create figure
+     fig = go.Figure()
+
+     # Add bar chart
+     fig.add_trace(
+         go.Bar(
+             x=quantile_labels,
+             y=quantile_means,
+             marker={"color": colors, "line": {"color": "white", "width": 1}},
+             hovertemplate="<b>%{x}</b><br>Mean Outcome: %{y:.4f}<extra></extra>",
+             name="Mean Outcome",
+         )
+     )
+
+     # Add zero line
+     fig.add_hline(y=0, line={"color": COLORS["grid"], "width": 1, "dash": "dash"})
+
+     # Build title and subtitle
+     if title is None:
+         title = "Monotonicity Analysis"
+
+     subtitle_parts = []
+     if show_correlation:
+         subtitle_parts.append(f"Correlation: {correlation:.3f} (p={p_value:.4f})")
+
+     subtitle_parts.append(f"Monotonicity: {monotonicity_score:.1%}")
+     subtitle_parts.append(f"Direction: {direction.replace('_', ' ').title()}")
+
+     if is_monotonic:
+         subtitle_parts.append("✓ Monotonic")
+     else:
+         subtitle_parts.append("✗ Not Monotonic")
+
+     subtitle = " | ".join(subtitle_parts)
+
+     # Update layout
+     fig.update_layout(
+         title={
+             "text": f"<b>{title}</b><br><sub>{subtitle}</sub>",
+             "x": 0.5,
+             "xanchor": "center",
+         },
+         xaxis_title="Feature Quantile",
+         yaxis_title="Mean Outcome",
+         showlegend=False,
+         **DEFAULT_LAYOUT,
+     )
+
+     return fig
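Likewise, the keys plot_monotonicity unpacks imply this input shape from compute_monotonicity() (values illustrative):

    result = {
        "quantile_labels": ["Q1", "Q2", "Q3", "Q4", "Q5"],
        "quantile_means": [-0.004, -0.001, 0.001, 0.003, 0.006],
        "correlation": 0.98,
        "p_value": 0.003,
        "is_monotonic": True,
        "monotonicity_score": 1.0,
        "direction": "increasing",
    }
    fig = plot_monotonicity(result)
    fig.show()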