ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1060 @@
1
+ """Core plotting utilities for ML4T Diagnostic visualizations.
2
+
3
+ Provides theme management, color schemes, validation helpers, and
4
+ common layout patterns used across all plot functions.
5
+
6
+ This module implements the standards defined in docs/plot_api_standards.md.
7
+ """
8
+
9
+ from typing import Any
10
+
11
+ import plotly.express as px
12
+ import plotly.graph_objects as go
13
+
14
+ # =============================================================================
15
+ # Global Theme State
16
+ # =============================================================================
17
+
18
# Name of the active theme (a key of AVAILABLE_THEMES). set_plot_theme()
# rebinds it and get_plot_theme() reads it. Helpers that accept theme=None
# fall back to this value.
_CURRENT_THEME = "default"  # Global theme setting
19
+
20
+
21
def set_plot_theme(theme: str) -> None:
    """Set the global plot theme for all subsequent visualizations.

    The name is stored in the module-level ``_CURRENT_THEME``. Helpers that
    are called with ``theme=None`` (``get_theme_config``, ``validate_theme``
    and therefore ``create_base_figure``) resolve to this theme.

    Parameters
    ----------
    theme : str
        Theme name; must be a key of ``AVAILABLE_THEMES``
        ("default", "dark", "print", "presentation").

    Raises
    ------
    ValueError
        If theme name is not recognized. The current theme is left unchanged.

    Examples
    --------
    >>> from ml4t.diagnostic.visualization.core import get_plot_theme, set_plot_theme
    >>> set_plot_theme("dark")
    >>> get_plot_theme()
    'dark'
    >>> set_plot_theme("default")  # restore the global state
    """
    global _CURRENT_THEME

    # Validate before rebinding so a bad name never becomes the active theme.
    if theme not in AVAILABLE_THEMES:
        raise ValueError(
            f"Unknown theme '{theme}'. Available themes: {', '.join(AVAILABLE_THEMES.keys())}"
        )

    _CURRENT_THEME = theme
49
+
50
+
51
def get_plot_theme() -> str:
    """Get the current global plot theme.

    Returns
    -------
    str
        Name of the active theme (a key of ``AVAILABLE_THEMES``). This is
        ``"default"`` until ``set_plot_theme`` changes it.

    Examples
    --------
    >>> from ml4t.diagnostic.visualization.core import get_plot_theme
    >>> get_plot_theme()
    'default'
    """
    return _CURRENT_THEME
66
+
67
+
68
+ # =============================================================================
69
+ # Theme Definitions
70
+ # =============================================================================
71
+
72
# Light, general-purpose theme; also the initial global theme.
THEME_DEFAULT = dict(
    name="default",
    description="Clean, modern light theme for general use",
    layout=dict(
        paper_bgcolor="#FFFFFF",
        plot_bgcolor="#F8F9FA",
        font=dict(
            family="Inter, -apple-system, system-ui, sans-serif",
            size=12,
            color="#2C3E50",
        ),
        title_font=dict(
            size=18,
            color="#2C3E50",
            family="Inter, -apple-system, system-ui, sans-serif",
        ),
        margin={"l": 80, "r": 20, "t": 100, "b": 80},
        hovermode="closest",
        hoverlabel=dict(bgcolor="white", font_size=13, font_family="Inter, sans-serif"),
    ),
    colorway=[
        "#3498DB",  # Blue
        "#E74C3C",  # Red
        "#2ECC71",  # Green
        "#F39C12",  # Orange
        "#9B59B6",  # Purple
        "#1ABC9C",  # Teal
        "#E67E22",  # Dark orange
        "#95A5A6",  # Gray
    ],
    color_schemes=dict(sequential="Blues", diverging="RdBu", qualitative="Set2"),
    defaults=dict(
        bar_height=600,
        heatmap_height=800,
        scatter_height=700,
        line_height=500,
        width=1000,
    ),
)
115
+
116
# Dark-background theme; lighter trace colors keep contrast on #2D2D2D.
THEME_DARK = dict(
    name="dark",
    description="Dark mode theme for dashboards and presentations",
    layout=dict(
        paper_bgcolor="#1E1E1E",
        plot_bgcolor="#2D2D2D",
        font=dict(
            family="Inter, -apple-system, system-ui, sans-serif",
            size=12,
            color="#E0E0E0",
        ),
        title_font=dict(
            size=18,
            color="#FFFFFF",
            family="Inter, -apple-system, system-ui, sans-serif",
        ),
        margin={"l": 80, "r": 20, "t": 100, "b": 80},
        hovermode="closest",
        hoverlabel=dict(
            bgcolor="#3D3D3D",
            font_size=13,
            font_family="Inter, sans-serif",
            font_color="#FFFFFF",
        ),
    ),
    colorway=[
        "#5DADE2",  # Light blue
        "#EC7063",  # Light red
        "#58D68D",  # Light green
        "#F5B041",  # Light orange
        "#AF7AC5",  # Light purple
        "#48C9B0",  # Light teal
        "#EB984E",  # Light dark orange
        "#AAB7B8",  # Light gray
    ],
    color_schemes=dict(sequential="Blues", diverging="RdBu", qualitative="Set2"),
    defaults=dict(
        bar_height=600,
        heatmap_height=800,
        scatter_height=700,
        line_height=500,
        width=1000,
    ),
)
164
+
165
# Grayscale, serif-font theme sized for publication figures.
THEME_PRINT = dict(
    name="print",
    description="Publication-quality grayscale theme",
    layout=dict(
        paper_bgcolor="#FFFFFF",
        plot_bgcolor="#FFFFFF",
        font=dict(family="Times New Roman, serif", size=11, color="#000000"),
        title_font=dict(size=14, color="#000000", family="Times New Roman, serif"),
        margin={"l": 60, "r": 20, "t": 80, "b": 60},
        hovermode="closest",
        hoverlabel=dict(
            bgcolor="white",
            font_size=11,
            font_family="Times New Roman, serif",
        ),
    ),
    colorway=[
        "#000000",  # Black
        "#444444",  # Dark gray
        "#888888",  # Medium gray
        "#BBBBBB",  # Light gray
    ],
    color_schemes=dict(sequential="Greys", diverging="RdGy", qualitative="Greys"),
    defaults=dict(
        bar_height=500,
        heatmap_height=700,
        scatter_height=600,
        line_height=450,
        width=800,
    ),
)
200
+
201
# High-contrast theme with enlarged fonts, margins and figure sizes for slides.
THEME_PRESENTATION = dict(
    name="presentation",
    description="High-contrast theme for slides and presentations",
    layout=dict(
        paper_bgcolor="#FFFFFF",
        plot_bgcolor="#F0F0F0",
        font=dict(
            family="Inter, -apple-system, system-ui, sans-serif",
            size=16,  # Larger fonts
            color="#000000",
        ),
        title_font=dict(
            size=24,  # Much larger title
            color="#000000",
            family="Inter, -apple-system, system-ui, sans-serif",
        ),
        margin={"l": 100, "r": 40, "t": 120, "b": 100},
        hovermode="closest",
        hoverlabel=dict(bgcolor="white", font_size=16, font_family="Inter, sans-serif"),
    ),
    colorway=[
        "#0066CC",  # Strong blue
        "#FF3333",  # Strong red
        "#00CC66",  # Strong green
        "#FF9900",  # Strong orange
        "#9933CC",  # Strong purple
        "#00CCCC",  # Strong teal
    ],
    color_schemes=dict(sequential="Blues", diverging="RdBu", qualitative="Bold"),
    defaults=dict(
        bar_height=700,
        heatmap_height=900,
        scatter_height=800,
        line_height=600,
        width=1200,
    ),
)
242
+
243
# Registry of built-in themes, keyed by each theme's own "name" field. The
# insertion order is the order listed in "Unknown theme" error messages.
AVAILABLE_THEMES = {
    theme_def["name"]: theme_def
    for theme_def in (THEME_DEFAULT, THEME_DARK, THEME_PRINT, THEME_PRESENTATION)
}
249
+
250
+
251
def get_theme_config(theme: str | None = None) -> dict[str, Any]:
    """Get complete theme configuration.

    Parameters
    ----------
    theme : str | None, default None
        Theme name. If None, uses current global theme

    Returns
    -------
    dict[str, Any]
        A deep copy of the registered theme dict. The built-in themes carry the
        keys ``name``, ``description``, ``layout``, ``colorway``,
        ``color_schemes`` and ``defaults``. Because it is a copy, callers may
        modify it without affecting the theme used by later figures.

    Raises
    ------
    ValueError
        If theme name is not recognized

    Examples
    --------
    >>> config = get_theme_config("dark")
    >>> config["layout"]["paper_bgcolor"]
    '#1E1E1E'
    """
    import copy

    if theme is None:
        theme = get_plot_theme()

    if theme not in AVAILABLE_THEMES:
        raise ValueError(
            f"Unknown theme '{theme}'. Available themes: {', '.join(AVAILABLE_THEMES.keys())}"
        )

    # Returning the registered dict itself would let a caller's tweak
    # (e.g. config["layout"]["margin"]["t"] = 40) leak into every later figure.
    return copy.deepcopy(AVAILABLE_THEMES[theme])
284
+
285
+
286
+ # =============================================================================
287
+ # Color Schemes
288
+ # =============================================================================
289
+
290
# Named color schemes. Keys are lower-case; get_color_scheme() lower-cases the
# requested name before looking it up. Every scheme named in a theme's
# "color_schemes" entry must resolve here (lower-cased), otherwise
# validate_color_scheme(None, <theme>) raises ValueError.
COLOR_SCHEMES = {
    # Sequential (single hue, light to dark)
    "blues": px.colors.sequential.Blues,
    "greens": px.colors.sequential.Greens,
    "reds": px.colors.sequential.Reds,
    "oranges": px.colors.sequential.Oranges,
    # Default sequential/qualitative scheme of the "print" theme ("Greys").
    "greys": px.colors.sequential.Greys,
    "viridis": px.colors.sequential.Viridis,
    "cividis": px.colors.sequential.Cividis,
    "plasma": px.colors.sequential.Plasma,
    # Diverging (two hues with neutral center)
    "rdbu": px.colors.diverging.RdBu,
    "rdylgn": px.colors.diverging.RdYlGn,
    "brbg": px.colors.diverging.BrBG,
    "prgn": px.colors.diverging.PRGn,
    # Default diverging scheme of the "print" theme ("RdGy").
    "rdgy": px.colors.diverging.RdGy,
    "blues_oranges": ["#0571B0", "#92C5DE", "#F7F7F7", "#F4A582", "#CA0020"],
    # Qualitative (distinct colors for categories)
    "set2": px.colors.qualitative.Set2,
    "set3": px.colors.qualitative.Set3,
    "pastel": px.colors.qualitative.Pastel,
    "dark2": px.colors.qualitative.Dark2,
    "bold": px.colors.qualitative.Bold,
    # Financial
    "gains_losses": ["#FF4444", "#CCCCCC", "#00CC88"],  # Red, gray, green
    "quantiles": ["#D32F2F", "#F57C00", "#FBC02D", "#689F38", "#388E3C"],
    # Color-blind safe
    "colorblind_safe": [
        "#0173B2",
        "#DE8F05",
        "#029E73",
        "#CC78BC",
        "#5B4E96",
        "#A65628",
        "#F0E442",
        "#999999",
    ],
}
326
+
327
+
328
def get_color_scheme(name: str) -> list[str]:
    """Get a named color scheme.

    Parameters
    ----------
    name : str
        Color scheme name, case-insensitive (see COLOR_SCHEMES for options)

    Returns
    -------
    list[str]
        A new list of color strings. It is safe to modify; the registered
        scheme is not affected.

    Raises
    ------
    ValueError
        If color scheme name is not recognized

    Examples
    --------
    >>> get_color_scheme("quantiles")
    ['#D32F2F', '#F57C00', '#FBC02D', '#689F38', '#388E3C']
    >>> get_color_scheme("Gains_Losses")
    ['#FF4444', '#CCCCCC', '#00CC88']
    """
    name = name.lower()

    if name not in COLOR_SCHEMES:
        raise ValueError(
            f"Unknown color scheme '{name}'. Available: {', '.join(COLOR_SCHEMES.keys())}"
        )

    # Many entries are plotly's own module-level palette lists
    # (px.colors.sequential.*, ...). Hand out a copy so an in-place edit by a
    # caller (reverse(), append(), ...) cannot corrupt them globally.
    return list(COLOR_SCHEMES[name])
360
+
361
+
362
def get_colorscale(
    name: str, n_colors: int | None = None, reverse: bool = False
) -> list[str] | list[tuple[float, str]]:
    """Get a color scale for continuous or discrete coloring.

    Parameters
    ----------
    name : str
        Color scheme name (case-insensitive, see COLOR_SCHEMES)
    n_colors : int | None, default None
        Number of discrete colors. If None, the scheme's full list of colors
        is returned; Plotly accepts such a list directly as a continuous
        colorscale.
    reverse : bool, default False
        Reverse the color order (applied before any sampling)

    Returns
    -------
    list[str] | list[tuple[float, str]]
        Always a newly built list of color strings. If ``n_colors`` is at most
        the scheme length, evenly spaced entries (first and last included) are
        picked from the scheme as-is. Otherwise the colors are interpolated
        with ``plotly.colors.sample_colorscale``.

    Raises
    ------
    ValueError
        If the color scheme name is not recognized

    Examples
    --------
    >>> # Continuous colorscale
    >>> scale = get_colorscale("viridis")
    >>> # Discrete colors
    >>> colors = get_colorscale("viridis", n_colors=5)
    >>> len(colors)
    5
    """
    # Copy up front: the scheme list may be a plotly module-level palette, and
    # the n_colors=None path hands the list straight back to the caller.
    colors = list(get_color_scheme(name))

    if reverse:
        colors.reverse()

    if n_colors is None:
        # Return as continuous colorscale
        return colors

    # Sample n_colors from the scheme
    if n_colors <= len(colors):
        # Use evenly spaced colors including both endpoints
        import numpy as np

        indices = np.linspace(0, len(colors) - 1, n_colors, dtype=int)
        return [colors[i] for i in indices]

    # More colors requested than the scheme defines: interpolate.
    import plotly.colors as pc

    return pc.sample_colorscale(colors, n_colors)
411
+
412
+
413
+ # =============================================================================
414
+ # Validation Helpers
415
+ # =============================================================================
416
+
417
+
418
def validate_plot_results(
    results: dict[str, Any], required_keys: list[str], function_name: str
) -> None:
    """Check that ``results`` is a dict containing every required key.

    Parameters
    ----------
    results : dict[str, Any]
        Output of an ``analyze_*()`` function.
    required_keys : list[str]
        Keys that the plotting function needs to find in ``results``.
    function_name : str
        Caller's name, embedded in the error message.

    Raises
    ------
    TypeError
        If ``results`` is not a ``dict`` (or subclass).
    ValueError
        If any of ``required_keys`` is absent; the message lists the absent
        keys in the order they were requested.

    Examples
    --------
    >>> validate_plot_results(
    ...     results,
    ...     required_keys=["consensus_ranking", "method_results"],
    ...     function_name="plot_importance_bar"
    ... )
    """
    if not isinstance(results, dict):
        received = type(results).__name__
        raise TypeError(
            f"{function_name} requires dict from analyze_*() function, got {received}"
        )

    absent = [key for key in required_keys if key not in results]
    if not absent:
        return

    message = (
        f"Invalid results dict for {function_name}. "
        f"Missing keys: {absent}. "
        "Expected output from corresponding analyze_*() function."
    )
    raise ValueError(message)
459
+
460
+
461
+ def validate_positive_int(value: int | None, name: str) -> None:
462
+ """Validate that value is a positive integer.
463
+
464
+ Parameters
465
+ ----------
466
+ value : int | None
467
+ Value to validate
468
+ name : str
469
+ Parameter name (for error messages)
470
+
471
+ Raises
472
+ ------
473
+ ValueError
474
+ If value is not a positive integer
475
+
476
+ Examples
477
+ --------
478
+ >>> validate_positive_int(10, "top_n") # OK
479
+ >>> validate_positive_int(-5, "top_n") # Raises ValueError
480
+ """
481
+ if value is not None and (not isinstance(value, int) or value < 1):
482
+ raise ValueError(f"{name} must be a positive integer, got {value}")
483
+
484
+
485
def validate_theme(theme: str | None) -> str:
    """Resolve ``theme`` to a registered theme name.

    Parameters
    ----------
    theme : str | None
        Explicit theme name, or None to use the global theme from
        ``get_plot_theme()``.

    Returns
    -------
    str
        The resolved name, guaranteed to be a key of ``AVAILABLE_THEMES``.

    Raises
    ------
    ValueError
        If the resolved name is not a key of ``AVAILABLE_THEMES``.

    Examples
    --------
    >>> validate_theme("dark")
    'dark'
    >>> validate_theme(None)  # Returns global theme
    'default'
    """
    resolved = get_plot_theme() if theme is None else theme

    if resolved in AVAILABLE_THEMES:
        return resolved

    known = ", ".join(AVAILABLE_THEMES.keys())
    raise ValueError(f"Unknown theme '{resolved}'. Available themes: {known}")
519
+
520
+
521
def validate_color_scheme(scheme: str | None, theme: str) -> str:
    """Resolve ``scheme`` to a lower-cased key of ``COLOR_SCHEMES``.

    Parameters
    ----------
    scheme : str | None
        Color scheme name in any letter case, or None to use the theme's
        default *sequential* scheme.
    theme : str
        Theme whose ``color_schemes["sequential"]`` entry supplies the
        fallback when ``scheme`` is None.

    Returns
    -------
    str
        The lower-cased scheme name.

    Raises
    ------
    ValueError
        If the (lower-cased) name is not a key of ``COLOR_SCHEMES``. This also
        applies to the theme's fallback name, so every theme's default
        sequential scheme must be registered there.

    Examples
    --------
    >>> validate_color_scheme("viridis", "default")
    'viridis'
    >>> validate_color_scheme(None, "default")  # Uses theme default
    'blues'
    """
    if scheme is None:
        # Themes store display-cased names such as "Blues"; normalized below.
        scheme = get_theme_config(theme)["color_schemes"]["sequential"]

    normalized = scheme.lower()
    if normalized in COLOR_SCHEMES:
        return normalized

    raise ValueError(
        f"Unknown color scheme '{normalized}'. Available: {', '.join(COLOR_SCHEMES.keys())}"
    )
561
+
562
+
563
+ # =============================================================================
564
+ # Layout Helpers
565
+ # =============================================================================
566
+
567
+
568
def create_base_figure(
    title: str | None = None,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    width: int | None = None,
    height: int | None = None,
    theme: str | None = None,
    margin: dict[str, int] | None = None,
) -> go.Figure:
    """Create an empty figure with the theme's layout applied.

    The layout combines the given titles and size, the theme's ``layout``
    settings (background, fonts, margins, hover style) and the theme's
    ``colorway``, which sets the default trace color cycle.

    Parameters
    ----------
    title : str | None, default None
        Figure title
    xaxis_title : str | None, default None
        X-axis label
    yaxis_title : str | None, default None
        Y-axis label
    width : int | None, default None
        Figure width in pixels. A falsy value (None or 0) falls back to the
        theme's ``defaults["width"]``.
    height : int | None, default None
        Figure height in pixels. None leaves the height unset.
    theme : str | None, default None
        Theme name; None uses the global theme.
    margin : dict[str, int] | None, default None
        Margin overrides, merged onto the theme's margins. For example
        ``{"t": 40}`` changes only the top margin.

    Returns
    -------
    go.Figure
        Configured Plotly figure

    Raises
    ------
    ValueError
        If the theme name is not recognized.

    Examples
    --------
    >>> fig = create_base_figure(
    ...     title="Feature Importance",
    ...     xaxis_title="Features",
    ...     yaxis_title="Importance Score",
    ...     theme="dark"
    ... )
    """
    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)
    theme_layout = theme_config["layout"]

    fig = go.Figure()

    # Build layout
    layout = {
        "title": title,
        "xaxis_title": xaxis_title,
        "yaxis_title": yaxis_title,
        "width": width or theme_config["defaults"]["width"],
        "height": height,
        **theme_layout,
    }

    # The theme's trace palette lives outside its "layout" section, so it has
    # to be copied in explicitly or Plotly's default colors would be used.
    colorway = theme_config.get("colorway")
    if colorway:
        layout["colorway"] = list(colorway)

    if margin is not None:
        # Merge rather than replace so unspecified sides keep the theme margins.
        layout["margin"] = {**theme_layout.get("margin", {}), **margin}

    fig.update_layout(layout)

    return fig
631
+
632
+
633
def apply_responsive_layout(fig: go.Figure) -> go.Figure:
    """Make figure responsive (adapts to container width).

    Enables ``autosize``, lets margins auto-expand, and clears any fixed
    ``layout.width``. Plotly's ``autosize`` only initializes a width/height
    that has been left undefined, so a fixed width (``create_base_figure``
    always sets one from the theme defaults) would otherwise keep the figure
    from resizing with its container. An explicitly set height is left
    unchanged.

    Parameters
    ----------
    fig : go.Figure
        Figure to make responsive (modified in place)

    Returns
    -------
    go.Figure
        The same figure, for chaining

    Examples
    --------
    >>> fig = create_base_figure(title="Test")
    >>> fig = apply_responsive_layout(fig)
    """
    fig.update_layout(
        autosize=True,
        # Setting None removes the fixed width so autosize can take effect.
        width=None,
        margin={"autoexpand": True},
    )

    return fig
657
+
658
+
659
def add_annotation(
    fig: go.Figure,
    text: str,
    x: float,
    y: float,
    xref: str = "paper",
    yref: str = "paper",
    showarrow: bool = False,
    **kwargs,
) -> go.Figure:
    """Attach a text annotation to ``fig`` and hand the figure back.

    Thin wrapper over ``fig.add_annotation`` that defaults to paper-relative
    coordinates and no arrow.

    Parameters
    ----------
    fig : go.Figure
        Figure that receives the annotation (modified in place)
    text : str
        Text to display
    x : float
        Horizontal anchor; in [0, 1] when ``xref`` is "paper"
    y : float
        Vertical anchor; in [0, 1] when ``yref`` is "paper"
    xref : str, default "paper"
        Coordinate system for ``x``, e.g. "paper" or "x"
    yref : str, default "paper"
        Coordinate system for ``y``, e.g. "paper" or "y"
    showarrow : bool, default False
        Whether to draw an arrow pointing at the anchor
    **kwargs
        Further annotation properties forwarded unchanged (font, xanchor, ...)

    Returns
    -------
    go.Figure
        The same ``fig`` object, for chaining

    Examples
    --------
    >>> fig = create_base_figure(title="Test")
    >>> fig = add_annotation(
    ...     fig,
    ...     text="Key insight here",
    ...     x=0.5, y=0.95,
    ...     font=dict(size=14, color="red")
    ... )
    """
    # The named parameters can never collide with **kwargs, so merging them
    # into one mapping is equivalent to passing both.
    annotation_options = {
        "text": text,
        "x": x,
        "y": y,
        "xref": xref,
        "yref": yref,
        "showarrow": showarrow,
    }
    annotation_options.update(kwargs)
    fig.add_annotation(**annotation_options)
    return fig
708
+
709
+
710
+ # =============================================================================
711
+ # Format Helpers
712
+ # =============================================================================
713
+
714
+
715
def format_hover_template(
    x_label: str = "x",
    y_label: str = "y",
    x_format: str = "",
    y_format: str = ".3f",
    extra: str = "",
) -> str:
    """Create a Plotly hover template string.

    The x value is rendered in bold on the first line and ``"<y_label>: <y>"``
    on the second. ``extra`` (if non-empty) is appended as a further line, and
    the template ends with ``<extra></extra>``.

    Parameters
    ----------
    x_label : str, default "x"
        Label for x value. Not currently rendered: the x value is shown on its
        own, in bold. Kept for API compatibility.
    y_label : str, default "y"
        Label for y value
    x_format : str, default ""
        Format specifier for the x value (e.g. ",.0f"); empty means no
        formatting. The ``:`` separator required by Plotly is inserted
        automatically; a specifier that already starts with ``:`` is kept.
    y_format : str, default ".3f"
        Format specifier for the y value, handled like ``x_format``
    extra : str, default ""
        Extra text to display on its own line

    Returns
    -------
    str
        Plotly hover template string

    Examples
    --------
    >>> template = format_hover_template(
    ...     x_label="Feature",
    ...     y_label="Importance",
    ...     y_format=".4f"
    ... )
    >>> template
    '<b>%{x}</b><br>Importance: %{y:.4f}<extra></extra>'
    """

    def _placeholder(variable: str, fmt: str) -> str:
        # Plotly's syntax is %{var:format}. Without the colon ("%{y.3f}") the
        # specifier is not recognized as a format and the value is lost.
        if not fmt:
            return f"%{{{variable}}}"
        spec = fmt if fmt.startswith(":") else f":{fmt}"
        return f"%{{{variable}{spec}}}"

    lines = [
        f"<b>{_placeholder('x', x_format)}</b>",
        f"{y_label}: {_placeholder('y', y_format)}",
    ]
    if extra:
        lines.append(extra)

    return "<br>".join(lines) + "<extra></extra>"
760
+
761
+
762
def format_number(value: float, precision: int = 3) -> str:
    """Render a number with comma thousands separators and fixed decimals.

    Parameters
    ----------
    value : float
        Number to render
    precision : int, default 3
        Digits after the decimal point; 0 omits the decimal point entirely

    Returns
    -------
    str
        The formatted number

    Examples
    --------
    >>> format_number(0.123456, precision=2)
    '0.12'
    >>> format_number(1234567, precision=0)
    '1,234,567'
    """
    # A single spec covers every precision, including 0 (",.0f").
    spec = ",." + str(precision) + "f"
    return format(value, spec)
787
+
788
+
789
def format_percentage(value: float, precision: int = 1) -> str:
    """Render a fraction as a percentage string.

    Parameters
    ----------
    value : float
        Fraction to render, where 0.05 means 5%
    precision : int, default 1
        Digits after the decimal point

    Returns
    -------
    str
        The value scaled by 100, formatted with ``precision`` decimals and
        suffixed with "%"

    Examples
    --------
    >>> format_percentage(0.05, precision=1)
    '5.0%'
    >>> format_percentage(0.125, precision=1)
    '12.5%'
    """
    scaled = value * 100
    return "{:.{}f}%".format(scaled, precision)
812
+
813
+
814
+ # =============================================================================
815
+ # Common Plot Elements
816
+ # =============================================================================
817
+
818
+
819
def add_threshold_line(
    fig: go.Figure,
    y: float,
    label: str | None = None,
    color: str = "gray",
    dash: str = "dash",
    line_width: float = 1,
    opacity: float = 0.8,
    row: int | None = None,
    col: int | None = None,
    annotation_position: str = "right",
) -> go.Figure:
    """Add a horizontal threshold line to a figure.

    The optional label is created through ``add_hline``'s own annotation
    support. It is therefore placed on the same subplot(s) as the line,
    including when ``row``/``col`` target a subplot other than the first.

    Parameters
    ----------
    fig : go.Figure
        Figure to modify (in place)
    y : float
        Y-axis value for the line
    label : str | None, default None
        Optional label shown just above the line; empty/None means no label
    color : str, default "gray"
        Line color, also used for the label text
    dash : str, default "dash"
        Line style: "solid", "dot", "dash", "longdash", "dashdot"
    line_width : float, default 1
        Line width in pixels
    opacity : float, default 0.8
        Line opacity (0-1)
    row : int | None, default None
        Subplot row (for subplots); forwarded only when given
    col : int | None, default None
        Subplot column (for subplots); forwarded only when given
    annotation_position : str, default "right"
        Label position: "right", or anything else for "left"

    Returns
    -------
    go.Figure
        Modified figure

    Examples
    --------
    >>> fig = create_base_figure(title="Returns")
    >>> fig = add_threshold_line(fig, y=0, label="Zero line")
    >>> fig = add_threshold_line(fig, y=0.05, label="Target", color="green")
    """
    hline_kwargs = {
        "y": y,
        "line_dash": dash,
        "line_color": color,
        "line_width": line_width,
        "opacity": opacity,
    }

    # Forward row/col only when given so add_hline keeps its own defaults.
    if row is not None:
        hline_kwargs["row"] = row
    if col is not None:
        hline_kwargs["col"] = col

    if label:
        # A separate annotation with xref="paper"/yref="y" was always
        # positioned against the first subplot's y-axis; add_hline resolves
        # the correct axes for the targeted subplot(s).
        side = "right" if annotation_position == "right" else "left"
        hline_kwargs.update(
            annotation_text=label,
            annotation_position=f"top {side}",
            annotation_font_size=10,
            annotation_font_color=color,
        )

    fig.add_hline(**hline_kwargs)

    return fig
898
+
899
+
900
def _color_to_rgba(color: str, opacity: float) -> str:
    """Convert a color spec to an ``rgba(r, g, b, opacity)`` string.

    Accepts hex colors (``#RGB``, ``#RGBA``, ``#RRGGBB``, ``#RRGGBBAA``),
    ``rgb(...)``/``rgba(...)`` strings (matched case-insensitively) and a small
    set of named colors. Any alpha already present in the input is replaced by
    ``opacity``. Unknown names fall back to gray.

    Raises
    ------
    ValueError
        If hex digits cannot be parsed, or an rgb()/rgba() string has no
        parenthesized list of at least three components.
    """
    spec = color.strip()

    if spec.startswith("#"):
        digits = spec[1:]
        if len(digits) in (3, 4):
            # Expand CSS shorthand, e.g. "#3ad" -> "33aadd" (alpha dropped).
            digits = "".join(ch * 2 for ch in digits[:3])
        r, g, b = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
        return f"rgba({r}, {g}, {b}, {opacity})"

    if spec.lower().startswith("rgb"):
        # Covers both rgb(r, g, b) and rgba(r, g, b, a). Plain string
        # replacement used to turn "rgba(...)" into "rgbaa(..., a, opacity)".
        open_idx = spec.find("(")
        close_idx = spec.rfind(")")
        if open_idx == -1 or close_idx < open_idx:
            raise ValueError(f"Invalid rgb color string: {color!r}")
        components = [part.strip() for part in spec[open_idx + 1 : close_idx].split(",")]
        if len(components) < 3:
            raise ValueError(f"Invalid rgb color string: {color!r}")
        r, g, b = components[:3]
        return f"rgba({r}, {g}, {b}, {opacity})"

    named_colors = {
        "blue": (52, 152, 219),
        "red": (231, 76, 60),
        "green": (46, 204, 113),
        "orange": (243, 156, 18),
        "purple": (155, 89, 182),
        "gray": (128, 128, 128),
    }
    r, g, b = named_colors.get(spec.lower(), (128, 128, 128))
    return f"rgba({r}, {g}, {b}, {opacity})"


def add_confidence_band(
    fig: go.Figure,
    x: list | Any,
    y_lower: list | Any,
    y_upper: list | Any,
    color: str = "blue",
    opacity: float = 0.2,
    name: str = "CI",
    showlegend: bool = False,
) -> go.Figure:
    """Add a shaded confidence band to a figure.

    Creates a filled area between y_lower and y_upper bounds.

    Parameters
    ----------
    fig : go.Figure
        Figure to modify (in place)
    x : array-like
        X-axis values
    y_lower : array-like
        Lower bound values, same length as ``x``
    y_upper : array-like
        Upper bound values, same length as ``x``
    color : str, default "blue"
        Fill color: hex (#RGB/#RRGGBB, optional alpha), rgb()/rgba(), or one of
        blue/red/green/orange/purple/gray (other names fall back to gray)
    opacity : float, default 0.2
        Fill opacity (0-1); replaces any alpha given in ``color``
    name : str, default "CI"
        Legend name
    showlegend : bool, default False
        Show in legend

    Returns
    -------
    go.Figure
        Modified figure

    Raises
    ------
    ValueError
        If ``x``, ``y_lower`` and ``y_upper`` differ in length, or ``color``
        cannot be parsed.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.arange(100)
    >>> y_mean = np.sin(x / 10)
    >>> y_lower = y_mean - 0.2
    >>> y_upper = y_mean + 0.2
    >>> fig = create_base_figure(title="Signal with CI")
    >>> fig = add_confidence_band(fig, x, y_lower, y_upper, color="#3498DB")
    """
    import numpy as np

    x_vals = list(x)
    lower = list(y_lower)
    upper = list(y_upper)

    # A length mismatch would silently produce a garbled polygon.
    if not (len(x_vals) == len(lower) == len(upper)):
        raise ValueError(
            "x, y_lower and y_upper must have the same length "
            f"(got {len(x_vals)}, {len(lower)}, {len(upper)})"
        )

    # Closed polygon: along the upper bound left-to-right, then back along
    # the lower bound right-to-left; fill="toself" shades the enclosed area.
    # np.concatenate keeps homogeneous arrays (e.g. datetime64) intact.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_vals, x_vals[::-1]]),
            y=np.concatenate([upper, lower[::-1]]),
            fill="toself",
            fillcolor=_color_to_rgba(color, opacity),
            line={"color": "rgba(0,0,0,0)"},  # Invisible outline
            hoverinfo="skip",
            showlegend=showlegend,
            name=name,
        )
    )

    return fig
1000
+
1001
+
1002
+ # =============================================================================
1003
+ # Error Message Helpers
1004
+ # =============================================================================
1005
+
1006
+
1007
def require_plotly() -> None:
    """Check that Plotly is installed, raise helpful error if not.

    Raises
    ------
    ImportError
        If ``plotly.graph_objects`` cannot be imported. The underlying import
        error is chained as ``__cause__``, so a failure inside an installed
        Plotly (e.g. a broken dependency) stays diagnosable.

    Examples
    --------
    >>> require_plotly()  # OK if plotly installed
    """
    try:
        import plotly.graph_objects as go  # noqa: F401 (availability check)
    except ImportError as exc:
        raise ImportError(
            "Plotly is required for visualization. Install with:\n"
            " pip install plotly\n"
            "Or install ML4T Diagnostic with viz extras:\n"
            " pip install ml4t-diagnostic[viz]"
        ) from exc
1028
+
1029
+
1030
def require_kaleido() -> None:
    """Check that kaleido is installed (for image export).

    Raises
    ------
    ImportError
        If ``kaleido`` cannot be imported. The underlying import error is
        chained as ``__cause__`` so the real failure is not hidden.

    Examples
    --------
    >>> require_kaleido()  # OK if kaleido installed
    """
    try:
        import kaleido  # noqa: F401 (availability check)
    except ImportError as exc:
        raise ImportError(
            "Kaleido is required for image export. Install with:\n"
            " pip install kaleido\n"
            "Or install ML4T Diagnostic with viz extras:\n"
            " pip install ml4t-diagnostic[viz]"
        ) from exc
1051
+
1052
+
1053
# plotly.express is imported at module level for the color-scheme code.
# Plotly must be available whenever this module is imported, so a missing
# installation fails fast with an install hint; the original ImportError is
# chained as __cause__ rather than discarded.
try:
    import plotly.express as px
except ImportError as exc:
    raise ImportError(
        "Plotly is required for visualization. Install with:\n pip install plotly"
    ) from exc