ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,618 @@
1
+ """Feature interaction visualization functions.
2
+
3
+ This module provides functions for visualizing feature interaction analysis results
4
+ from analyze_interactions(), compute_shap_interactions(), and related functions.
5
+
6
+ All plot functions follow the standard API defined in docs/plot_api_standards.md:
7
+ - Consume results dicts from analyze_*() or compute_*() functions
8
+ - Return plotly.graph_objects.Figure instances
9
+ - Support theme customization via global or per-plot settings
10
+ - Use keyword-only arguments (after results)
11
+ - Provide comprehensive hover information and interactivity
12
+
13
+ Example workflow:
14
+ >>> from ml4t.diagnostic.evaluation import analyze_interactions, compute_shap_interactions
15
+ >>> from ml4t.diagnostic.visualization import (
16
+ ... plot_interaction_bar,
17
+ ... plot_interaction_heatmap,
18
+ ... plot_interaction_network,
19
+ ... set_plot_theme
20
+ ... )
21
+ >>>
22
+ >>> # Analyze interactions
23
+ >>> results = analyze_interactions(model, X, y)
24
+ >>>
25
+ >>> # Or use SHAP directly
26
+ >>> shap_results = compute_shap_interactions(model, X, top_k=20)
27
+ >>>
28
+ >>> # Create visualizations
29
+ >>> fig_bar = plot_interaction_bar(shap_results, top_n=15)
30
+ >>> fig_heatmap = plot_interaction_heatmap(shap_results)
31
+ >>> fig_network = plot_interaction_network(shap_results, threshold=0.01)
32
+ >>>
33
+ >>> # Display or save
34
+ >>> fig_network.show()
35
+ >>> fig_heatmap.write_html("interactions_report.html")
36
+ """
37
+
38
+ from typing import Any
39
+
40
+ import numpy as np
41
+ import plotly.graph_objects as go
42
+
43
+ from ml4t.diagnostic.visualization.core import (
44
+ apply_responsive_layout,
45
+ get_color_scheme,
46
+ get_colorscale,
47
+ get_theme_config,
48
+ validate_plot_results,
49
+ validate_positive_int,
50
+ validate_theme,
51
+ )
52
+
53
+ __all__ = [
54
+ "plot_interaction_bar",
55
+ "plot_interaction_heatmap",
56
+ "plot_interaction_network",
57
+ ]
58
+
59
+
60
def plot_interaction_bar(
    results: dict[str, Any],
    *,
    title: str | None = None,
    top_n: int | None = 20,
    theme: str | None = None,
    color_scheme: str | None = None,
    width: int | None = None,
    height: int | None = None,
    show_values: bool = True,
) -> go.Figure:
    """Draw the strongest feature interactions as a horizontal bar chart.

    Each bar is one feature pair labelled ``"<a> × <b>"``; bar length and
    bar colour both encode the absolute value of the pair's score.

    Parameters
    ----------
    results : dict[str, Any]
        Must provide one of the following keys (checked in this order):

        - ``"top_interactions"``: sequence of ``(feature_a, feature_b, score)``
          triples, used as-is.
        - ``"consensus_ranking"``: sequence whose entries start with
          ``(feature_a, feature_b, score, ...)``; only the first three
          elements of each entry are used.
    title : str | None, optional
        Plot title. Falls back to "Feature Interactions - Top Pairs".
    top_n : int | None, optional
        Keep only the first ``top_n`` entries of the input sequence (the
        input order is trusted; no re-sorting happens here). ``None`` keeps
        every entry. Default is 20.
    theme : str | None, optional
        Theme name, passed through ``validate_theme``.
    color_scheme : str | None, optional
        Name handed to ``get_colorscale``. Falls back to "viridis".
    width : int | None, optional
        Figure width in pixels. Falls back to 1000.
    height : int | None, optional
        Figure height in pixels. Falls back to
        ``max(400, 25 * number_of_bars + 100)``.
    show_values : bool, optional
        When True (default) each bar is annotated with its value formatted
        to three decimals.

    Returns
    -------
    go.Figure
        Figure holding a single horizontal ``go.Bar`` trace with a colour
        bar titled "Strength". The first input entry is drawn at the top.

    Raises
    ------
    ValueError
        If ``results`` has neither "top_interactions" nor
        "consensus_ranking".

    Notes
    -----
    ``validate_theme`` and ``validate_positive_int`` may raise on invalid
    arguments; their exact exception types are defined outside this function.
    """
    # Argument validation happens before any data is touched.
    theme = validate_theme(theme)
    limit_requested = top_n is not None
    if limit_requested:
        validate_positive_int(top_n, "top_n")

    # Accept either result layout; "top_interactions" wins when both exist.
    if "top_interactions" in results:
        ranked_pairs = results["top_interactions"]
    elif "consensus_ranking" in results:
        ranked_pairs = []
        for entry in results["consensus_ranking"]:
            # Entries may carry extra trailing data; keep (a, b, score) only.
            ranked_pairs.append((entry[0], entry[1], entry[2]))
    else:
        raise ValueError(
            "Results must contain 'top_interactions' (from compute_shap_interactions) "
            "or 'consensus_ranking' (from analyze_interactions)"
        )

    if limit_requested:
        ranked_pairs = ranked_pairs[:top_n]

    # Build the axis data, then flip it: plotly draws the first category at
    # the bottom of a horizontal bar chart, and the leading entry belongs on top.
    bar_labels = [f"{left} × {right}" for left, right, _ in ranked_pairs]
    bar_lengths = [abs(score) for _, _, score in ranked_pairs]
    bar_labels.reverse()
    bar_lengths.reverse()

    theme_config = get_theme_config(theme)
    scale = get_colorscale(color_scheme or "viridis")

    if show_values:
        bar_text = [f"{length:.3f}" for length in bar_lengths]
    else:
        bar_text = None

    marker_style = {
        "color": bar_lengths,
        "colorscale": scale,
        "showscale": True,
        "colorbar": {
            "title": "Strength",
            "tickformat": ".3f",
        },
    }

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=bar_lengths,
            y=bar_labels,
            orientation="h",
            marker=marker_style,
            text=bar_text,
            textposition="outside",
            hovertemplate="<b>%{y}</b><br>Interaction: %{x:.4f}<extra></extra>",
        )
    )

    # Grow the figure with the number of bars unless the caller fixed a height.
    auto_height = max(400, len(bar_labels) * 25 + 100)
    fig.update_layout(
        title=title or "Feature Interactions - Top Pairs",
        xaxis_title="Interaction Strength",
        yaxis_title="Feature Pairs",
        **theme_config["layout"],
        width=width or 1000,
        height=height or auto_height,
        showlegend=False,
    )

    apply_responsive_layout(fig)
    return fig
224
+
225
+
226
def plot_interaction_heatmap(
    results: dict[str, Any],
    *,
    title: str | None = None,
    theme: str | None = None,
    color_scheme: str | None = None,
    width: int | None = None,
    height: int | None = None,
    show_values: bool = False,  # False by default - can be crowded
) -> go.Figure:
    """Plot heatmap of a pairwise feature interaction matrix.

    Cell ``(i, j)`` shows the value stored for the pair
    ``(feature_names[i], feature_names[j])``. Diagonal cells are labelled
    "Main Effect" in the hover text, off-diagonal cells "Interaction".

    Parameters
    ----------
    results : dict[str, Any]
        Must contain:

        - ``"interaction_matrix"``: square 2-D array-like of shape
          ``(n_features, n_features)``. Anything ``numpy.asarray`` accepts
          (ndarray, nested lists, ...) is allowed.
        - ``"feature_names"``: sequence of ``n_features`` labels.
    title : str | None, optional
        Plot title. Falls back to "Feature Interaction Matrix".
    theme : str | None, optional
        Theme name, passed through ``validate_theme``.
    color_scheme : str | None, optional
        Name handed to ``get_colorscale``. Falls back to "viridis".
    width : int | None, optional
        Figure width in pixels. Falls back to 800.
    height : int | None, optional
        Figure height in pixels. Falls back to 800.
    show_values : bool, optional
        When True, every cell is annotated with its value rounded to three
        decimals. Default is False (annotations get crowded quickly).

    Returns
    -------
    go.Figure
        Figure holding a single ``go.Heatmap`` trace with per-cell hover
        text and a colour bar titled "Strength". The y axis is reversed so
        the first feature appears at the top.

    Raises
    ------
    ValueError
        If the interaction matrix is not a square matrix whose side equals
        ``len(feature_names)``.

    Notes
    -----
    ``validate_plot_results`` and ``validate_theme`` are called first and may
    raise for missing keys or an unknown theme; their exception types are
    defined outside this function.
    """
    # Validate inputs
    validate_plot_results(
        results,
        required_keys=["interaction_matrix", "feature_names"],
        function_name="plot_interaction_heatmap",
    )
    theme = validate_theme(theme)

    # Coerce once so nested lists (or other array-likes) support [i, j]
    # indexing and np.round below, exactly like an ndarray would.
    feature_names = list(results["feature_names"])
    interaction_matrix = np.asarray(results["interaction_matrix"])
    n_features = len(feature_names)

    # A shape mismatch would otherwise surface as an IndexError deep in the
    # hover-text loop, or silently mislabel cells when the matrix is larger.
    if interaction_matrix.shape != (n_features, n_features):
        raise ValueError(
            "plot_interaction_heatmap: 'interaction_matrix' must have shape "
            f"({n_features}, {n_features}) to match 'feature_names', "
            f"got {interaction_matrix.shape}"
        )

    # Get theme and colors
    theme_config = get_theme_config(theme)
    colors = get_colorscale(color_scheme or "viridis")

    # Per-cell hover text: the diagonal is described as a main effect,
    # everything else as an interaction between the row and column feature.
    hover_text = []
    for i, row_name in enumerate(feature_names):
        row = []
        for j, col_name in enumerate(feature_names):
            value = interaction_matrix[i, j]
            if i == j:
                row.append(f"<b>{row_name}</b><br>Main Effect: {value:.4f}")
            else:
                row.append(
                    f"<b>{row_name}</b> × <b>{col_name}</b><br>Interaction: {value:.4f}"
                )
        hover_text.append(row)

    # Create figure
    fig = go.Figure()

    fig.add_trace(
        go.Heatmap(
            z=interaction_matrix,
            x=feature_names,
            y=feature_names,
            colorscale=colors,
            colorbar={
                "title": "Strength",
                "tickformat": ".3f",
            },
            text=np.round(interaction_matrix, 3) if show_values else None,
            texttemplate="%{text}" if show_values else None,
            textfont={"size": 10},
            hovertext=hover_text,
            hovertemplate="%{hovertext}<extra></extra>",
        )
    )

    # Update layout
    fig.update_layout(
        title=title or "Feature Interaction Matrix",
        xaxis={
            "title": "",
            "side": "bottom",
            # Slant the tick labels once there are too many to fit horizontally.
            "tickangle": -45 if n_features > 10 else 0,
        },
        yaxis={
            "title": "",
            "autorange": "reversed",  # Top to bottom
        },
        **theme_config["layout"],
        width=width or 800,
        height=height or 800,
    )

    # Apply responsive layout
    apply_responsive_layout(fig)

    return fig
381
+
382
+
383
def plot_interaction_network(
    results: dict[str, Any],
    *,
    title: str | None = None,
    threshold: float | None = None,
    top_n: int | None = None,
    theme: str | None = None,
    color_scheme: str | None = None,
    width: int | None = None,
    height: int | None = None,
    node_size: int = 30,
    show_edge_labels: bool = False,
) -> go.Figure:
    """Plot network graph of feature interactions.

    Features are drawn as nodes on a circle and each retained interaction as
    a straight edge whose thickness grows with the interaction's absolute
    strength. Only interactions at or above ``threshold`` are drawn.

    Parameters
    ----------
    results : dict[str, Any]
        Must provide one of the following (checked in this order):

        - ``"top_interactions"``: sequence of ``(feature_a, feature_b, score)``
          triples, used in the order given (presumably strongest first;
          this function does not re-sort them).
        - ``"interaction_matrix"`` and ``"feature_names"``: the upper
          triangle (excluding the diagonal) is converted to triples and
          sorted by absolute score, strongest first.
    title : str | None, optional
        Plot title. Falls back to "Feature Interaction Network".
    threshold : float | None, optional
        Minimum absolute interaction strength to display (inclusive). If
        None, the larger of the median and the 80th percentile of the
        absolute scores is used; the 80th percentile is never below the
        median, so it is the effective cutoff.
    top_n : int | None, optional
        Maximum number of interactions to display, applied after the
        threshold filter. If None, every interaction that passes the
        threshold is shown.
    theme : str | None, optional
        Theme name, passed through ``validate_theme``. The theme is
        validated, but its layout settings are currently not applied to the
        network figure.
    color_scheme : str | None, optional
        Name handed to ``get_color_scheme`` for node colours. Falls back to
        "set2". Colours are cycled when there are more nodes than colours.
    width : int | None, optional
        Figure width in pixels. Falls back to 1000.
    height : int | None, optional
        Figure height in pixels. Falls back to 800.
    node_size : int, optional
        Marker size of the nodes. Default is 30.
    show_edge_labels : bool, optional
        Whether to annotate each edge midpoint with its absolute value
        (two decimals). Default is False (can be cluttered).

    Returns
    -------
    go.Figure
        Figure with one line trace per edge (width between 1 and 6) and one
        marker+text trace holding all nodes. Edges and nodes carry hover text.

    Raises
    ------
    ValueError
        If ``results`` lacks the required keys, if it contains no
        interactions at all, or if no interaction remains after filtering.

    Notes
    -----
    - Nodes are placed evenly on a unit circle in sorted name order; no
      force-directed layout is computed.
    - Features that take part in no retained interaction are not drawn.
    - ``validate_theme`` and ``validate_positive_int`` may raise on invalid
      arguments; their exception types are defined outside this function.
    """
    # Validate inputs
    theme = validate_theme(theme)
    if top_n is not None:
        validate_positive_int(top_n, "top_n")

    # Extract interactions
    if "top_interactions" in results:
        interactions = list(results["top_interactions"])
    elif "interaction_matrix" in results and "feature_names" in results:
        # Convert matrix to interaction list
        matrix = results["interaction_matrix"]
        feature_names = results["feature_names"]
        n_features = len(feature_names)

        # Upper triangle only: the diagonal is not a pairwise interaction and
        # (j, i) would duplicate (i, j).
        interactions = [
            (feature_names[i], feature_names[j], matrix[i, j])
            for i in range(n_features)
            for j in range(i + 1, n_features)
        ]

        # Sort by strength
        interactions.sort(key=lambda x: abs(x[2]), reverse=True)
    else:
        raise ValueError(
            "Results must contain 'top_interactions' or 'interaction_matrix' + 'feature_names'"
        )

    # Fail early with a clear message: the adaptive threshold below would
    # otherwise be computed on an empty list and come out as NaN.
    if not interactions:
        raise ValueError("No interactions available to plot: results contain no feature pairs.")

    # Apply threshold
    if threshold is None:
        # Adaptive threshold: the stricter of median and 80th percentile.
        values = [abs(val) for _, _, val in interactions]
        median_threshold = np.median(values)
        percentile_threshold = np.percentile(values, 80)
        threshold = max(median_threshold, percentile_threshold)

    filtered_interactions = [(f1, f2, val) for f1, f2, val in interactions if abs(val) >= threshold]

    # Apply top_n limit
    if top_n is not None:
        filtered_interactions = filtered_interactions[:top_n]

    if not filtered_interactions:
        raise ValueError(
            f"No interactions above threshold {threshold:.4f}. Try lowering threshold or increasing top_n."
        )

    # Only features that appear in a retained edge become nodes.
    node_set: set[str] = set()
    for f1, f2, _ in filtered_interactions:
        node_set.add(f1)
        node_set.add(f2)
    nodes = sorted(node_set)
    node_indices = {node: i for i, node in enumerate(nodes)}

    # Simple circular layout for nodes
    n_nodes = len(nodes)
    angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
    radius = 1.0

    node_x = radius * np.cos(angles)
    node_y = radius * np.sin(angles)

    # The returned theme config is intentionally unused (see layout note
    # below); the call is kept so its behaviour for the given theme is unchanged.
    get_theme_config(theme)
    node_colors = get_color_scheme(color_scheme or "set2")

    # Create figure
    fig = go.Figure()

    # Add edges
    max_interaction = max(abs(val) for _, _, val in filtered_interactions)
    for f1, f2, val in filtered_interactions:
        i1 = node_indices[f1]
        i2 = node_indices[f2]

        # Edge thickness proportional to interaction strength. When every
        # retained interaction is exactly 0 there is nothing to scale by, so
        # all edges get the minimum width instead of dividing by zero.
        relative_strength = abs(val) / max_interaction if max_interaction > 0 else 0.0
        edge_width = 1 + 5 * relative_strength

        fig.add_trace(
            go.Scatter(
                x=[node_x[i1], node_x[i2]],
                y=[node_y[i1], node_y[i2]],
                mode="lines",
                line={"width": edge_width, "color": "rgba(125,125,125,0.5)"},
                hoverinfo="text",
                hovertext=f"{f1} × {f2}<br>Interaction: {abs(val):.4f}",
                showlegend=False,
            )
        )

        # Optional edge labels
        if show_edge_labels:
            mid_x = (node_x[i1] + node_x[i2]) / 2
            mid_y = (node_y[i1] + node_y[i2]) / 2
            fig.add_annotation(
                x=mid_x,
                y=mid_y,
                text=f"{abs(val):.2f}",
                showarrow=False,
                font={"size": 8},
            )

    # Add nodes (after the edges so the markers are drawn on top of the lines)
    fig.add_trace(
        go.Scatter(
            x=node_x,
            y=node_y,
            mode="markers+text",
            marker={
                "size": node_size,
                "color": [node_colors[i % len(node_colors)] for i in range(n_nodes)],
                "line": {"width": 2, "color": "white"},
            },
            text=nodes,
            textposition="top center",
            textfont={"size": 10},
            hoverinfo="text",
            hovertext=nodes,
            showlegend=False,
        )
    )

    # Update layout (using simpler approach to avoid theme conflicts)
    fig.update_layout(
        title=title or "Feature Interaction Network",
        width=width or 1000,
        height=height or 800,
        xaxis={"showgrid": False, "zeroline": False, "showticklabels": False},
        yaxis={"showgrid": False, "zeroline": False, "showticklabels": False},
        hovermode="closest",
    )

    # Apply responsive layout
    apply_responsive_layout(fig)

    return fig
@@ -0,0 +1,41 @@
1
+ """Portfolio visualization module.
2
+
3
+ Plotly-based interactive visualizations for portfolio analysis.
4
+ Replacement for pyfolio's matplotlib-based plots.
5
+ """
6
+
7
+ from .dashboard import create_portfolio_dashboard
8
+ from .drawdown_plots import (
9
+ plot_drawdown_periods,
10
+ plot_drawdown_underwater,
11
+ )
12
+ from .returns_plots import (
13
+ plot_annual_returns_bar,
14
+ plot_cumulative_returns,
15
+ plot_monthly_returns_heatmap,
16
+ plot_returns_distribution,
17
+ plot_rolling_returns,
18
+ )
19
+ from .risk_plots import (
20
+ plot_rolling_beta,
21
+ plot_rolling_sharpe,
22
+ plot_rolling_volatility,
23
+ )
24
+
25
+ __all__ = [
26
+ # Returns
27
+ "plot_cumulative_returns",
28
+ "plot_rolling_returns",
29
+ "plot_annual_returns_bar",
30
+ "plot_monthly_returns_heatmap",
31
+ "plot_returns_distribution",
32
+ # Risk
33
+ "plot_rolling_volatility",
34
+ "plot_rolling_sharpe",
35
+ "plot_rolling_beta",
36
+ # Drawdown
37
+ "plot_drawdown_underwater",
38
+ "plot_drawdown_periods",
39
+ # Dashboard
40
+ "create_portfolio_dashboard",
41
+ ]