ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1060 @@
|
|
|
1
|
+
"""Core plotting utilities for ML4T Diagnostic visualizations.
|
|
2
|
+
|
|
3
|
+
Provides theme management, color schemes, validation helpers, and
|
|
4
|
+
common layout patterns used across all plot functions.
|
|
5
|
+
|
|
6
|
+
This module implements the standards defined in docs/plot_api_standards.md.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import plotly.express as px
|
|
12
|
+
import plotly.graph_objects as go
|
|
13
|
+
|
|
14
|
+
# =============================================================================
|
|
15
|
+
# Global Theme State
|
|
16
|
+
# =============================================================================
|
|
17
|
+
|
|
18
|
+
# Name of the active global theme. get_plot_theme() returns it, set_plot_theme()
# replaces it, and get_theme_config()/validate_theme() fall back to it when
# called with theme=None.
_CURRENT_THEME = "default"  # Global theme setting
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def set_plot_theme(theme: str) -> None:
    """Set the global plot theme for all subsequent visualizations.

    Parameters
    ----------
    theme : str
        Theme name: "default", "dark", "print", "presentation"
        (any key of ``AVAILABLE_THEMES``).

    Raises
    ------
    ValueError
        If theme name is not recognized. ``None`` is not accepted here;
        it is rejected like any other unknown name.

    Examples
    --------
    >>> set_plot_theme("dark")
    >>> get_plot_theme()
    'dark'
    >>> set_plot_theme("default")  # restore the initial theme
    """
    global _CURRENT_THEME

    # Validate before assigning so an invalid name never becomes the global theme.
    if theme not in AVAILABLE_THEMES:
        raise ValueError(
            f"Unknown theme '{theme}'. Available themes: {', '.join(AVAILABLE_THEMES.keys())}"
        )

    _CURRENT_THEME = theme
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_plot_theme() -> str:
    """Get the current global plot theme.

    Returns
    -------
    str
        Current theme name, as last stored by ``set_plot_theme`` (initially
        ``"default"``).

    Examples
    --------
    >>> get_plot_theme()  # with no prior set_plot_theme() call
    'default'
    """
    return _CURRENT_THEME
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# =============================================================================
|
|
69
|
+
# Theme Definitions
|
|
70
|
+
# =============================================================================
|
|
71
|
+
|
|
72
|
+
# Light theme, registered under the name "default" in AVAILABLE_THEMES.
THEME_DEFAULT = {
    "name": "default",
    "description": "Clean, modern light theme for general use",
    # Layout settings; key names mirror Plotly `layout` property names.
    "layout": {
        "paper_bgcolor": "#FFFFFF",
        "plot_bgcolor": "#F8F9FA",
        "font": {
            "family": "Inter, -apple-system, system-ui, sans-serif",
            "size": 12,
            "color": "#2C3E50",
        },
        "title_font": {
            "size": 18,
            "color": "#2C3E50",
            "family": "Inter, -apple-system, system-ui, sans-serif",
        },
        "margin": {"l": 80, "r": 20, "t": 100, "b": 80},
        "hovermode": "closest",
        "hoverlabel": {"bgcolor": "white", "font_size": 13, "font_family": "Inter, sans-serif"},
    },
    # Ordered trace palette. NOTE(review): presumably applied as Plotly's
    # layout.colorway by figure helpers; confirm in callers.
    "colorway": [
        "#3498DB",  # Blue
        "#E74C3C",  # Red
        "#2ECC71",  # Green
        "#F39C12",  # Orange
        "#9B59B6",  # Purple
        "#1ABC9C",  # Teal
        "#E67E22",  # Dark orange
        "#95A5A6",  # Gray
    ],
    # Default scheme names. validate_color_scheme() lowercases "sequential"
    # and looks it up in COLOR_SCHEMES.
    "color_schemes": {
        "sequential": "Blues",
        "diverging": "RdBu",
        "qualitative": "Set2",
    },
    # Default figure sizes per plot type. NOTE(review): presumably pixel
    # values used as Plotly width/height; confirm against callers.
    "defaults": {
        "bar_height": 600,
        "heatmap_height": 800,
        "scatter_height": 700,
        "line_height": 500,
        "width": 1000,
    },
}
|
|
115
|
+
|
|
116
|
+
# Dark theme, registered under the name "dark" in AVAILABLE_THEMES.
THEME_DARK = {
    "name": "dark",
    "description": "Dark mode theme for dashboards and presentations",
    # Layout settings; key names mirror Plotly `layout` property names.
    "layout": {
        "paper_bgcolor": "#1E1E1E",
        "plot_bgcolor": "#2D2D2D",
        "font": {
            "family": "Inter, -apple-system, system-ui, sans-serif",
            "size": 12,
            "color": "#E0E0E0",
        },
        "title_font": {
            "size": 18,
            "color": "#FFFFFF",
            "family": "Inter, -apple-system, system-ui, sans-serif",
        },
        "margin": {"l": 80, "r": 20, "t": 100, "b": 80},
        "hovermode": "closest",
        # Unlike the light themes, the hover label needs an explicit light
        # font color to stay readable on its dark background.
        "hoverlabel": {
            "bgcolor": "#3D3D3D",
            "font_size": 13,
            "font_family": "Inter, sans-serif",
            "font_color": "#FFFFFF",
        },
    },
    # Lighter variants of the default palette, for contrast on dark backgrounds.
    "colorway": [
        "#5DADE2",  # Light blue
        "#EC7063",  # Light red
        "#58D68D",  # Light green
        "#F5B041",  # Light orange
        "#AF7AC5",  # Light purple
        "#48C9B0",  # Light teal
        "#EB984E",  # Light dark orange
        "#AAB7B8",  # Light gray
    ],
    # Default scheme names. validate_color_scheme() lowercases "sequential"
    # and looks it up in COLOR_SCHEMES.
    "color_schemes": {
        "sequential": "Blues",
        "diverging": "RdBu",
        "qualitative": "Set2",
    },
    # Default figure sizes per plot type. NOTE(review): presumably pixel
    # values; confirm against callers.
    "defaults": {
        "bar_height": 600,
        "heatmap_height": 800,
        "scatter_height": 700,
        "line_height": 500,
        "width": 1000,
    },
}
|
|
164
|
+
|
|
165
|
+
# Grayscale publication theme, registered under the name "print" in AVAILABLE_THEMES.
THEME_PRINT = {
    "name": "print",
    "description": "Publication-quality grayscale theme",
    # Layout settings; key names mirror Plotly `layout` property names.
    # White plot background and serif fonts suit print.
    "layout": {
        "paper_bgcolor": "#FFFFFF",
        "plot_bgcolor": "#FFFFFF",
        "font": {"family": "Times New Roman, serif", "size": 11, "color": "#000000"},
        "title_font": {"size": 14, "color": "#000000", "family": "Times New Roman, serif"},
        "margin": {"l": 60, "r": 20, "t": 80, "b": 60},
        "hovermode": "closest",
        "hoverlabel": {
            "bgcolor": "white",
            "font_size": 11,
            "font_family": "Times New Roman, serif",
        },
    },
    # Four gray levels, from black to light gray.
    "colorway": [
        "#000000",  # Black
        "#444444",  # Dark gray
        "#888888",  # Medium gray
        "#BBBBBB",  # Light gray
    ],
    # NOTE(review): validate_color_scheme() lowercases these names and looks
    # them up in COLOR_SCHEMES, so "greys" and "rdgy" must exist there or
    # resolution raises ValueError.
    # NOTE(review): "Greys" is listed as the qualitative scheme although it
    # names a graded palette; confirm intent.
    "color_schemes": {
        "sequential": "Greys",
        "diverging": "RdGy",
        "qualitative": "Greys",
    },
    # Smaller default figure sizes than the screen themes. NOTE(review):
    # presumably pixel values; confirm against callers.
    "defaults": {
        "bar_height": 500,
        "heatmap_height": 700,
        "scatter_height": 600,
        "line_height": 450,
        "width": 800,
    },
}
|
|
200
|
+
|
|
201
|
+
# Large-font, high-contrast theme, registered under the name "presentation"
# in AVAILABLE_THEMES.
THEME_PRESENTATION = {
    "name": "presentation",
    "description": "High-contrast theme for slides and presentations",
    # Layout settings; key names mirror Plotly `layout` property names.
    "layout": {
        "paper_bgcolor": "#FFFFFF",
        "plot_bgcolor": "#F0F0F0",
        "font": {
            "family": "Inter, -apple-system, system-ui, sans-serif",
            "size": 16,  # Larger fonts
            "color": "#000000",
        },
        "title_font": {
            "size": 24,  # Much larger title
            "color": "#000000",
            "family": "Inter, -apple-system, system-ui, sans-serif",
        },
        # Wider margins leave room for the larger tick and axis labels.
        "margin": {"l": 100, "r": 40, "t": 120, "b": 100},
        "hovermode": "closest",
        "hoverlabel": {"bgcolor": "white", "font_size": 16, "font_family": "Inter, sans-serif"},
    },
    # Saturated palette for visibility on projectors.
    "colorway": [
        "#0066CC",  # Strong blue
        "#FF3333",  # Strong red
        "#00CC66",  # Strong green
        "#FF9900",  # Strong orange
        "#9933CC",  # Strong purple
        "#00CCCC",  # Strong teal
    ],
    # Default scheme names. validate_color_scheme() lowercases "sequential"
    # and looks it up in COLOR_SCHEMES.
    "color_schemes": {
        "sequential": "Blues",
        "diverging": "RdBu",
        "qualitative": "Bold",
    },
    # Largest default figure sizes of all themes. NOTE(review): presumably
    # pixel values; confirm against callers.
    "defaults": {
        "bar_height": 700,
        "heatmap_height": 900,
        "scatter_height": 800,
        "line_height": 600,
        "width": 1200,
    },
}
|
|
242
|
+
|
|
243
|
+
# Registry mapping theme name -> theme configuration dict.
# set_plot_theme(), get_theme_config() and validate_theme() check names
# against these keys and list them in their "Unknown theme" errors.
AVAILABLE_THEMES = {
    "default": THEME_DEFAULT,
    "dark": THEME_DARK,
    "print": THEME_PRINT,
    "presentation": THEME_PRESENTATION,
}
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def get_theme_config(theme: str | None = None) -> dict[str, Any]:
    """Get complete theme configuration.

    Parameters
    ----------
    theme : str | None, default None
        Theme name. If None, uses current global theme

    Returns
    -------
    dict[str, Any]
        Deep copy of the theme configuration dict ("name", "description",
        "layout", "colorway", "color_schemes", "defaults"). Mutating the
        returned dict does not affect the registered theme or later plots.

    Raises
    ------
    ValueError
        If theme name is not recognized

    Examples
    --------
    >>> config = get_theme_config("dark")
    >>> config["layout"]["paper_bgcolor"]
    '#1E1E1E'
    """
    import copy

    # Resolve None to the global theme and reject unknown names. This uses the
    # same check and error message as the rest of the module.
    theme = validate_theme(theme)

    # Hand out a copy: the theme dicts are shared module state, and a caller
    # tweaking e.g. config["layout"]["margin"] would otherwise silently change
    # every subsequent plot.
    return copy.deepcopy(AVAILABLE_THEMES[theme])
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
# =============================================================================
|
|
287
|
+
# Color Schemes
|
|
288
|
+
# =============================================================================
|
|
289
|
+
|
|
290
|
+
# Named palettes keyed by lowercase name. get_color_scheme() and
# validate_color_scheme() lowercase their input before looking it up here, so
# every scheme name referenced by a theme's "color_schemes" needs an entry.
COLOR_SCHEMES = {
    # Sequential (single hue, light to dark)
    "blues": px.colors.sequential.Blues,
    "greens": px.colors.sequential.Greens,
    "reds": px.colors.sequential.Reds,
    "oranges": px.colors.sequential.Oranges,
    # Needed so THEME_PRINT's "Greys" default resolves.
    "greys": px.colors.sequential.Greys,
    "viridis": px.colors.sequential.Viridis,
    "cividis": px.colors.sequential.Cividis,
    "plasma": px.colors.sequential.Plasma,
    # Diverging (two hues with neutral center)
    "rdbu": px.colors.diverging.RdBu,
    "rdylgn": px.colors.diverging.RdYlGn,
    # Needed so THEME_PRINT's "RdGy" diverging default resolves.
    "rdgy": px.colors.diverging.RdGy,
    "brbg": px.colors.diverging.BrBG,
    "prgn": px.colors.diverging.PRGn,
    "blues_oranges": ["#0571B0", "#92C5DE", "#F7F7F7", "#F4A582", "#CA0020"],
    # Qualitative (distinct colors for categories)
    "set2": px.colors.qualitative.Set2,
    "set3": px.colors.qualitative.Set3,
    "pastel": px.colors.qualitative.Pastel,
    "dark2": px.colors.qualitative.Dark2,
    "bold": px.colors.qualitative.Bold,
    # Financial
    "gains_losses": ["#FF4444", "#CCCCCC", "#00CC88"],  # Red, gray, green
    "quantiles": ["#D32F2F", "#F57C00", "#FBC02D", "#689F38", "#388E3C"],
    # Color-blind safe
    "colorblind_safe": [
        "#0173B2",
        "#DE8F05",
        "#029E73",
        "#CC78BC",
        "#5B4E96",
        "#A65628",
        "#F0E442",
        "#999999",
    ],
}
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def get_color_scheme(name: str) -> list[str]:
    """Get a named color scheme.

    Parameters
    ----------
    name : str
        Color scheme name, case-insensitive (see COLOR_SCHEMES for options)

    Returns
    -------
    list[str]
        A new list of the scheme's color strings. It is a copy, so callers may
        modify it without altering the shared palette. The custom schemes
        defined here use hex codes; Plotly-provided palettes may use other CSS
        color formats (e.g. ``"rgb(...)"``).

    Raises
    ------
    ValueError
        If color scheme name is not recognized

    Examples
    --------
    >>> get_color_scheme("gains_losses")
    ['#FF4444', '#CCCCCC', '#00CC88']
    >>> get_color_scheme("GAINS_LOSSES") == get_color_scheme("gains_losses")
    True
    """
    name = name.lower()

    if name not in COLOR_SCHEMES:
        raise ValueError(
            f"Unknown color scheme '{name}'. Available: {', '.join(COLOR_SCHEMES.keys())}"
        )

    # Copy: several entries are Plotly's own module-level lists, and in-place
    # edits by a caller would leak into every later lookup (and into Plotly).
    return list(COLOR_SCHEMES[name])
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def get_colorscale(
    name: str, n_colors: int | None = None, reverse: bool = False
) -> list[str] | list[tuple[float, str]]:
    """Return a named palette either as-is or resampled to a fixed size.

    Parameters
    ----------
    name : str
        Color scheme name, resolved through ``get_color_scheme``.
    n_colors : int | None, default None
        How many colors to return. ``None`` returns the whole palette, usable
        as a continuous colorscale.
    reverse : bool, default False
        Flip the palette order before any sampling.

    Returns
    -------
    list[str] | list[tuple[float, str]]
        The palette itself when ``n_colors`` is None. Otherwise ``n_colors``
        colors: picked at evenly spaced indices when the palette is long
        enough, or interpolated with ``plotly.colors.sample_colorscale`` when
        more colors are requested than the palette holds.

    Examples
    --------
    >>> scale = get_colorscale("viridis")              # continuous
    >>> len(get_colorscale("viridis", n_colors=5))     # discrete
    5
    """
    palette = get_color_scheme(name)
    if reverse:
        palette = palette[::-1]

    if n_colors is None:
        return palette

    if n_colors > len(palette):
        # Not enough stored colors: interpolate between neighbours.
        import plotly.colors as pc

        return pc.sample_colorscale(palette, n_colors)

    # Enough colors: take evenly spaced positions, keeping both ends.
    import numpy as np

    positions = np.linspace(0, len(palette) - 1, n_colors, dtype=int)
    return [palette[int(pos)] for pos in positions]
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
# =============================================================================
|
|
414
|
+
# Validation Helpers
|
|
415
|
+
# =============================================================================
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def validate_plot_results(
    results: dict[str, Any], required_keys: list[str], function_name: str
) -> None:
    """Check that an ``analyze_*()`` results dict can be handed to a plot function.

    Parameters
    ----------
    results : dict[str, Any]
        Object produced by an ``analyze_*()`` function; must be a ``dict``.
    required_keys : list[str]
        Keys the plot function needs; each must be present in ``results``.
    function_name : str
        Name of the calling plot function, quoted in error messages.

    Raises
    ------
    TypeError
        When ``results`` is not a ``dict``.
    ValueError
        When one or more ``required_keys`` are absent. The message lists them
        in ``required_keys`` order.

    Examples
    --------
    >>> validate_plot_results(
    ...     results,
    ...     required_keys=["consensus_ranking", "method_results"],
    ...     function_name="plot_importance_bar"
    ... )
    """
    if isinstance(results, dict):
        absent = []
        for key in required_keys:
            if key not in results:
                absent.append(key)
        if not absent:
            return
        raise ValueError(
            f"Invalid results dict for {function_name}. "
            f"Missing keys: {absent}. "
            "Expected output from corresponding analyze_*() function."
        )

    raise TypeError(
        f"{function_name} requires dict from analyze_*() function, got {type(results).__name__}"
    )
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def validate_positive_int(value: int | None, name: str) -> None:
|
|
462
|
+
"""Validate that value is a positive integer.
|
|
463
|
+
|
|
464
|
+
Parameters
|
|
465
|
+
----------
|
|
466
|
+
value : int | None
|
|
467
|
+
Value to validate
|
|
468
|
+
name : str
|
|
469
|
+
Parameter name (for error messages)
|
|
470
|
+
|
|
471
|
+
Raises
|
|
472
|
+
------
|
|
473
|
+
ValueError
|
|
474
|
+
If value is not a positive integer
|
|
475
|
+
|
|
476
|
+
Examples
|
|
477
|
+
--------
|
|
478
|
+
>>> validate_positive_int(10, "top_n") # OK
|
|
479
|
+
>>> validate_positive_int(-5, "top_n") # Raises ValueError
|
|
480
|
+
"""
|
|
481
|
+
if value is not None and (not isinstance(value, int) or value < 1):
|
|
482
|
+
raise ValueError(f"{name} must be a positive integer, got {value}")
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def validate_theme(theme: str | None) -> str:
    """Resolve an optional theme name and confirm it is registered.

    Parameters
    ----------
    theme : str | None
        Explicit theme name, or None to fall back to the global theme from
        ``get_plot_theme()``.

    Returns
    -------
    str
        The resolved theme name, guaranteed to be a key of AVAILABLE_THEMES.

    Raises
    ------
    ValueError
        When the resolved name is not a key of AVAILABLE_THEMES.

    Examples
    --------
    >>> validate_theme("dark")
    'dark'
    >>> validate_theme(None)  # global theme; 'default' unless changed
    'default'
    """
    theme = get_plot_theme() if theme is None else theme

    if theme in AVAILABLE_THEMES:
        return theme

    raise ValueError(
        f"Unknown theme '{theme}'. Available themes: {', '.join(AVAILABLE_THEMES.keys())}"
    )
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def validate_color_scheme(scheme: str | None, theme: str) -> str:
    """Resolve a color scheme name (case-insensitively) and confirm it exists.

    Parameters
    ----------
    scheme : str | None
        Color scheme name, or None to use the theme's ``"sequential"`` scheme
        from ``get_theme_config(theme)["color_schemes"]``
    theme : str
        Theme name, only consulted when ``scheme`` is None

    Returns
    -------
    str
        Lower-cased color scheme name, guaranteed to be a key of ``COLOR_SCHEMES``

    Raises
    ------
    ValueError
        If the lower-cased name is not a key of ``COLOR_SCHEMES``

    Examples
    --------
    >>> validate_color_scheme("viridis", "default")
    'viridis'
    >>> validate_color_scheme(None, "default")  # Uses theme default
    'blues'
    """
    if scheme is None:
        scheme = get_theme_config(theme)["color_schemes"]["sequential"]

    normalized = scheme.lower()
    if normalized not in COLOR_SCHEMES:
        available = ", ".join(COLOR_SCHEMES.keys())
        raise ValueError(f"Unknown color scheme '{normalized}'. Available: {available}")

    return normalized
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
# =============================================================================
|
|
564
|
+
# Layout Helpers
|
|
565
|
+
# =============================================================================
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def create_base_figure(
    title: str | None = None,
    xaxis_title: str | None = None,
    yaxis_title: str | None = None,
    width: int | None = None,
    height: int | None = None,
    theme: str | None = None,
    margin: dict[str, int] | None = None,
) -> go.Figure:
    """Create a base figure with theme applied.

    The theme's layout is applied first, and the explicit arguments are then
    applied on top of it. A caller-supplied title, size or margin therefore
    always takes precedence over the theme. Unrelated theme styling (and
    styling of the title itself, such as font or position) is preserved,
    because Plotly's ``update_layout`` merges nested settings.

    Parameters
    ----------
    title : str | None, default None
        Figure title text. When None, the theme's title settings are untouched.
    xaxis_title : str | None, default None
        X-axis label. When None, the theme's x-axis title is untouched.
    yaxis_title : str | None, default None
        Y-axis label. When None, the theme's y-axis title is untouched.
    width : int | None, default None
        Figure width in pixels. Falls back to ``theme_config["defaults"]["width"]``
        when falsy.
    height : int | None, default None
        Figure height in pixels. When None, no height is set here (a height
        defined by the theme layout, if any, is kept).
    theme : str | None, default None
        Theme name, validated with ``validate_theme`` (None = global theme)
    margin : dict[str, int] | None, default None
        Margin overrides, merged over the theme's margin settings.

    Returns
    -------
    go.Figure
        Configured Plotly figure

    Raises
    ------
    ValueError
        If the theme name is not recognized (raised by ``validate_theme``)

    Examples
    --------
    >>> fig = create_base_figure(
    ...     title="Feature Importance",
    ...     xaxis_title="Features",
    ...     yaxis_title="Importance Score",
    ...     theme="dark"
    ... )
    """
    theme = validate_theme(theme)
    theme_config = get_theme_config(theme)

    fig = go.Figure()

    # 1) Theme layout first, so that anything set below overrides it.
    fig.update_layout(theme_config["layout"])

    # 2) Explicit arguments. Only values the caller actually supplied are sent,
    #    so None never wipes theme settings. The ``*_title_text`` magic-underscore
    #    paths set just the text and keep any theme title styling intact.
    overrides = {"width": width or theme_config["defaults"]["width"]}
    if height is not None:
        overrides["height"] = height
    if title is not None:
        overrides["title_text"] = title
    if xaxis_title is not None:
        overrides["xaxis_title_text"] = xaxis_title
    if yaxis_title is not None:
        overrides["yaxis_title_text"] = yaxis_title
    if margin is not None:
        overrides["margin"] = margin

    fig.update_layout(overrides)

    return fig
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
def apply_responsive_layout(fig: go.Figure) -> go.Figure:
    """Turn on Plotly's autosizing and margin auto-expansion for a figure.

    The figure is modified in place and also returned, so calls can be chained.

    NOTE(review): Plotly's ``autosize`` is documented to apply only to width/height
    values that are left undefined. Figures produced with an explicit width will
    keep that width. Confirm this is the intended behaviour.

    Parameters
    ----------
    fig : go.Figure
        Figure to make responsive

    Returns
    -------
    go.Figure
        The same figure object, after the layout update

    Examples
    --------
    >>> fig = create_base_figure(title="Test")
    >>> fig = apply_responsive_layout(fig)
    """
    responsive_settings = {
        "autosize": True,
        "margin": {"autoexpand": True},
    }
    fig.update_layout(responsive_settings)
    return fig
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
def add_annotation(
    fig: go.Figure,
    text: str,
    x: float,
    y: float,
    xref: str = "paper",
    yref: str = "paper",
    showarrow: bool = False,
    **kwargs,
) -> go.Figure:
    """Place a text annotation on a figure (thin wrapper over ``fig.add_annotation``).

    Parameters
    ----------
    fig : go.Figure
        Figure to annotate (modified in place)
    text : str
        Annotation text
    x : float
        X position (0-1 when ``xref="paper"``)
    y : float
        Y position (0-1 when ``yref="paper"``)
    xref : str, default "paper"
        X reference: "paper" or "x"
    yref : str, default "paper"
        Y reference: "paper" or "y"
    showarrow : bool, default False
        Whether to draw an arrow pointing at the position
    **kwargs
        Any further annotation properties, forwarded unchanged

    Returns
    -------
    go.Figure
        The same figure object, so calls can be chained

    Examples
    --------
    >>> fig = create_base_figure(title="Test")
    >>> fig = add_annotation(
    ...     fig,
    ...     text="Key insight here",
    ...     x=0.5, y=0.95,
    ...     font=dict(size=14, color="red")
    ... )
    """
    # kwargs can never repeat these keys: they are named parameters of this function.
    annotation_props = {
        "text": text,
        "x": x,
        "y": y,
        "xref": xref,
        "yref": yref,
        "showarrow": showarrow,
        **kwargs,
    }
    fig.add_annotation(**annotation_props)
    return fig
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
# =============================================================================
|
|
711
|
+
# Format Helpers
|
|
712
|
+
# =============================================================================
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
def _hover_placeholder(var: str, fmt: str) -> str:
    """Build a Plotly ``%{var...}`` placeholder, inserting ':' before a bare format spec."""
    # Plotly expects "%{y:.3f}" (number format) or "%{x|%Y-%m-%d}" (date format).
    # Specs already starting with ':' or '|' are used verbatim.
    if fmt and not fmt.startswith((":", "|")):
        fmt = ":" + fmt
    return f"%{{{var}{fmt}}}"


def format_hover_template(
    x_label: str = "x",
    y_label: str = "y",
    x_format: str = "",
    y_format: str = ".3f",
    extra: str = "",
) -> str:
    """Create a Plotly hover template string.

    The template shows the x value as a bold header, then ``"<y_label>: <y>"``,
    then an optional extra line. It ends with ``<extra></extra>``, which hides
    Plotly's secondary hover box.

    Parameters
    ----------
    x_label : str, default "x"
        Label for x value. Currently not rendered (the x value is shown on its
        own as the bold header); kept for API compatibility.
    y_label : str, default "y"
        Label for y value
    x_format : str, default ""
        Format spec for x value, e.g. ",.0f". A leading ':' is added
        automatically. Specs that already start with ':' (number format) or
        '|' (date format) are used as-is. Empty means Plotly's default formatting.
    y_format : str, default ".3f"
        Format spec for y value, with the same rules as ``x_format``.
    extra : str, default ""
        Extra text shown on its own line (omitted when empty)

    Returns
    -------
    str
        Plotly hover template string

    Examples
    --------
    >>> template = format_hover_template(
    ...     x_label="Feature",
    ...     y_label="Importance",
    ...     y_format=".4f"
    ... )
    >>> template
    '<b>%{x}</b><br>Importance: %{y:.4f}<extra></extra>'
    """
    parts = [
        f"<b>{_hover_placeholder('x', x_format)}</b>",
        f"<br>{y_label}: {_hover_placeholder('y', y_format)}",
    ]
    if extra:
        parts.append(f"<br>{extra}")
    parts.append("<extra></extra>")
    return "".join(parts)
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
def format_number(value: float, precision: int = 3) -> str:
    """Render a number with thousands separators and fixed decimals.

    Parameters
    ----------
    value : float
        Number to format
    precision : int, default 3
        Number of decimal places

    Returns
    -------
    str
        Formatted string (format spec ``",.<precision>f"``)

    Examples
    --------
    >>> format_number(0.123456, precision=2)
    '0.12'
    >>> format_number(1234567, precision=0)
    '1,234,567'
    """
    # Any value equal to zero (0, 0.0, False) is spelled as the literal 0.
    digits = 0 if precision == 0 else precision
    return format(value, f",.{digits}f")
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
def format_percentage(value: float, precision: int = 1) -> str:
    """Render a fraction as a percentage string.

    Parameters
    ----------
    value : float
        Fraction to format (0.05 means 5%)
    precision : int, default 1
        Number of decimal places

    Returns
    -------
    str
        ``value * 100`` in fixed-point notation followed by ``%``

    Examples
    --------
    >>> format_percentage(0.05, precision=1)
    '5.0%'
    >>> format_percentage(0.12345, precision=2)
    '12.35%'
    """
    scaled = value * 100
    return "{:.{}f}%".format(scaled, precision)
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
# =============================================================================
|
|
815
|
+
# Common Plot Elements
|
|
816
|
+
# =============================================================================
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
def add_threshold_line(
    fig: go.Figure,
    y: float,
    label: str | None = None,
    color: str = "gray",
    dash: str = "dash",
    line_width: float = 1,
    opacity: float = 0.8,
    row: int | None = None,
    col: int | None = None,
    annotation_position: str = "right",
) -> go.Figure:
    """Add a horizontal threshold line to a figure.

    The optional label is attached through ``fig.add_hline``'s own annotation
    support. It is therefore placed on the same subplot(s) as the line when
    ``row``/``col`` are given, rather than always on the first y-axis.

    Parameters
    ----------
    fig : go.Figure
        Figure to modify
    y : float
        Y-axis value for the line
    label : str | None, default None
        Optional label/annotation for the line (skipped when empty)
    color : str, default "gray"
        Line color (also used for the label text)
    dash : str, default "dash"
        Line style: "solid", "dot", "dash", "longdash", "dashdot"
    line_width : float, default 1
        Line width in pixels
    opacity : float, default 0.8
        Line opacity (0-1)
    row : int | None, default None
        Subplot row (for subplots); not forwarded when None
    col : int | None, default None
        Subplot column (for subplots); not forwarded when None
    annotation_position : str, default "right"
        Label position: "right" puts the label just above the line at the right
        edge; any other value puts it at the left edge.

    Returns
    -------
    go.Figure
        Modified figure

    Examples
    --------
    >>> fig = create_base_figure(title="Returns")
    >>> fig = add_threshold_line(fig, y=0, label="Zero line")
    >>> fig = add_threshold_line(fig, y=0.05, label="Target", color="green")
    """
    hline_kwargs = {
        "y": y,
        "line_dash": dash,
        "line_color": color,
        "line_width": line_width,
        "opacity": opacity,
    }

    # Only forward row/col when provided, so Plotly's own defaults apply otherwise.
    if row is not None:
        hline_kwargs["row"] = row
    if col is not None:
        hline_kwargs["col"] = col

    if label:
        # "top <side>" = text sitting just above the line at that edge.
        # The annotation follows the same row/col targeting as the line itself.
        side = "right" if annotation_position == "right" else "left"
        hline_kwargs["annotation_text"] = label
        hline_kwargs["annotation_position"] = f"top {side}"
        hline_kwargs["annotation_font"] = {"size": 10, "color": color}

    fig.add_hline(**hline_kwargs)

    return fig
|
|
898
|
+
|
|
899
|
+
|
|
900
|
+
def _band_fillcolor(color: str, opacity: float) -> str | None:
    """Convert *color* to an ``rgba(...)`` fill string carrying *opacity*, or None.

    Supported inputs:

    - ``#RGB``, ``#RGBA``, ``#RRGGBB`` and ``#RRGGBBAA`` hex; any alpha in the
      hex string is ignored.
    - ``rgb(...)`` and ``rgba(...)``; an existing alpha is replaced.
    - A small table of color names.

    None is returned when the color cannot be decomposed into RGB here.
    """
    text = color.strip()
    lowered = text.lower()

    if text.startswith("#"):
        digits = text[1:]
        if len(digits) in (3, 4):
            # Shorthand hex: every digit is doubled (#abc == #aabbcc).
            digits = "".join(ch * 2 for ch in digits[:3])
        if len(digits) not in (6, 8):
            return None
        try:
            r, g, b = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
        except ValueError:
            return None
        return f"rgba({r}, {g}, {b}, {opacity})"

    if lowered.startswith("rgb") and "(" in text and text.endswith(")"):
        components = [part.strip() for part in text[text.index("(") + 1 : -1].split(",")]
        if len(components) < 3:
            return None
        r, g, b = components[:3]  # drop any existing alpha
        return f"rgba({r}, {g}, {b}, {opacity})"

    named_rgb = {
        "blue": (52, 152, 219),
        "red": (231, 76, 60),
        "green": (46, 204, 113),
        "orange": (243, 156, 18),
        "purple": (155, 89, 182),
        "gray": (128, 128, 128),
    }
    rgb = named_rgb.get(lowered)
    if rgb is None:
        return None
    return f"rgba({rgb[0]}, {rgb[1]}, {rgb[2]}, {opacity})"


def add_confidence_band(
    fig: go.Figure,
    x: list | Any,
    y_lower: list | Any,
    y_upper: list | Any,
    color: str = "blue",
    opacity: float = 0.2,
    name: str = "CI",
    showlegend: bool = False,
) -> go.Figure:
    """Add a shaded confidence band to a figure.

    Creates a filled area between y_lower and y_upper bounds as a single
    closed ``go.Scatter`` polygon with an invisible outline and no hover.

    Parameters
    ----------
    fig : go.Figure
        Figure to modify
    x : array-like
        X-axis values
    y_lower : array-like
        Lower bound values (same length as ``x``)
    y_upper : array-like
        Upper bound values (same length as ``x``)
    color : str, default "blue"
        Fill color: hex (``#RGB``/``#RRGGBB``, optionally with alpha digits,
        which are ignored), ``rgb(...)``/``rgba(...)`` (alpha replaced), or
        a color name. The names blue, red, green, orange, purple and gray map
        to a fixed built-in palette. Other names are passed to Plotly as-is,
        and the opacity is then applied at the trace level.
    opacity : float, default 0.2
        Fill opacity (0-1)
    name : str, default "CI"
        Legend name
    showlegend : bool, default False
        Show in legend

    Returns
    -------
    go.Figure
        Modified figure

    Raises
    ------
    ValueError
        If ``x``, ``y_lower`` and ``y_upper`` differ in length

    Examples
    --------
    >>> import numpy as np
    >>> x = np.arange(100)
    >>> y_mean = np.sin(x / 10)
    >>> y_lower = y_mean - 0.2
    >>> y_upper = y_mean + 0.2
    >>> fig = create_base_figure(title="Signal with CI")
    >>> fig = add_confidence_band(fig, x, y_lower, y_upper, color="#3498DB")
    """
    import numpy as np

    # Materialise inputs as lists so that [::-1] reversal works for any iterable
    # (tuples, arrays, Series, generators) and lengths can be compared.
    xs = list(x)
    lower = list(y_lower)
    upper = list(y_upper)
    if not len(xs) == len(lower) == len(upper):
        raise ValueError(
            "x, y_lower and y_upper must have the same length, "
            f"got {len(xs)}, {len(lower)} and {len(upper)}"
        )

    fillcolor = _band_fillcolor(color, opacity)
    trace_extra = {}
    if fillcolor is None:
        # Not decomposable here (e.g. a CSS name outside the built-in table):
        # let Plotly interpret the color and apply transparency to the trace.
        fillcolor = color
        trace_extra["opacity"] = opacity

    # Closed polygon: the upper bound left-to-right, then the lower bound
    # right-to-left; fill="toself" shades the interior.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([xs, xs[::-1]]),
            y=np.concatenate([upper, lower[::-1]]),
            fill="toself",
            fillcolor=fillcolor,
            line={"color": "rgba(0,0,0,0)"},  # Invisible line
            hoverinfo="skip",
            showlegend=showlegend,
            name=name,
            **trace_extra,
        )
    )

    return fig
|
|
1000
|
+
|
|
1001
|
+
|
|
1002
|
+
# =============================================================================
|
|
1003
|
+
# Error Message Helpers
|
|
1004
|
+
# =============================================================================
|
|
1005
|
+
|
|
1006
|
+
|
|
1007
|
+
def require_plotly() -> None:
    """Ensure Plotly is importable, raising an ImportError with install hints otherwise.

    Raises
    ------
    ImportError
        If ``plotly.graph_objects`` cannot be imported

    Examples
    --------
    >>> require_plotly()  # OK if plotly installed
    """
    try:
        from plotly import graph_objects  # noqa: F401 (availability check)
    except ImportError:
        install_hint = (
            "Plotly is required for visualization. Install with:\n"
            " pip install plotly\n"
            "Or install ML4T Diagnostic with viz extras:\n"
            " pip install ml4t-diagnostic[viz]"
        )
        raise ImportError(install_hint)  # noqa: B904
|
|
1028
|
+
|
|
1029
|
+
|
|
1030
|
+
def require_kaleido() -> None:
    """Ensure kaleido (needed for static image export) is importable.

    Raises
    ------
    ImportError
        If ``kaleido`` cannot be imported; the message includes install hints

    Examples
    --------
    >>> require_kaleido()  # OK if kaleido installed
    """
    try:
        __import__("kaleido")  # availability check only; module is not used
    except ImportError:
        install_hint = (
            "Kaleido is required for image export. Install with:\n"
            " pip install kaleido\n"
            "Or install ML4T Diagnostic with viz extras:\n"
            " pip install ml4t-diagnostic[viz]"
        )
        raise ImportError(install_hint)  # noqa: B904
|
|
1051
|
+
|
|
1052
|
+
|
|
1053
|
+
# Module-level import of plotly.express, bound as ``px``.
# NOTE(review): the previous comment said this is needed "for color schemes";
# confirm which code in this module references ``px``.
# If Plotly is missing, importing this module fails with the install hint below.
try:
    from plotly import express as px
except ImportError:
    raise ImportError(  # noqa: B904
        "Plotly is required for visualization. Install with:\n pip install plotly"
    )
|