ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,618 @@
|
|
|
1
|
+
"""Feature interaction visualization functions.

This module provides functions for visualizing feature interaction analysis results
from analyze_interactions(), compute_shap_interactions(), and related functions.

All plot functions follow the standard API defined in docs/plot_api_standards.md:
- Consume results dicts from analyze_*() or compute_*() functions
- Return plotly.graph_objects.Figure instances
- Support theme customization via global or per-plot settings
- Use keyword-only arguments (after results)
- Provide comprehensive hover information and interactivity

Expected result keys per function:
- plot_interaction_bar: "top_interactions" (list of (feature_i, feature_j, score))
  or "consensus_ranking" (entries whose first three elements are
  (feature_i, feature_j, score)).
- plot_interaction_heatmap: "interaction_matrix" (n_features x n_features)
  and "feature_names".
- plot_interaction_network: "top_interactions", or "interaction_matrix" together
  with "feature_names".

Example workflow:
>>> from ml4t.diagnostic.evaluation import analyze_interactions, compute_shap_interactions
>>> from ml4t.diagnostic.visualization import (
...     plot_interaction_bar,
...     plot_interaction_heatmap,
...     plot_interaction_network,
...     set_plot_theme,
... )
>>>
>>> # Analyze interactions
>>> results = analyze_interactions(model, X, y)
>>>
>>> # Or use SHAP directly
>>> shap_results = compute_shap_interactions(model, X, top_k=20)
>>>
>>> # Create visualizations
>>> fig_bar = plot_interaction_bar(shap_results, top_n=15)
>>> fig_heatmap = plot_interaction_heatmap(shap_results)
>>> fig_network = plot_interaction_network(shap_results, threshold=0.01)
>>>
>>> # Display or save
>>> fig_network.show()
>>> fig_heatmap.write_html("interactions_report.html")
"""

from typing import Any

import numpy as np
import plotly.graph_objects as go

from ml4t.diagnostic.visualization.core import (
    apply_responsive_layout,
    get_color_scheme,
    get_colorscale,
    get_theme_config,
    validate_plot_results,
    validate_positive_int,
    validate_theme,
)

# Public plotting API re-exported by ``from ... import *``.
__all__ = [
    "plot_interaction_bar",
    "plot_interaction_heatmap",
    "plot_interaction_network",
]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def plot_interaction_bar(
    results: dict[str, Any],
    *,
    title: str | None = None,
    top_n: int | None = 20,
    theme: str | None = None,
    color_scheme: str | None = None,
    width: int | None = None,
    height: int | None = None,
    show_values: bool = True,
) -> go.Figure:
    """Plot horizontal bar chart of top feature interactions.

    Creates an interactive bar chart showing the strongest feature interactions
    in the order supplied by ``results``. Each bar represents a feature pair
    with color-coding by strength.

    Parameters
    ----------
    results : dict[str, Any]
        Results from compute_shap_interactions() or analyze_interactions() containing:
        - "top_interactions": list[tuple[str, str, float]] - Feature pairs with scores
        OR
        - "consensus_ranking": list[tuple[str, str, float, dict]] - From analyze_interactions()
        If both keys are present, "top_interactions" is used. Only the first
        three elements of each entry (feature_i, feature_j, score) are read.
    title : str | None, optional
        Plot title. If None, uses "Feature Interactions - Top Pairs"
    top_n : int | None, optional
        Number of top interactions to display, taken from the front of the
        (deduplicated) input ranking. If None, shows all.
        Default is 20 to avoid overcrowding.
    theme : str | None, optional
        Theme name ("default", "dark", "print", "presentation").
        If None, uses current global theme.
    color_scheme : str | None, optional
        Color scheme for bars. If None, uses "viridis".
        Recommended: "viridis", "cividis", "plasma", "oranges", "reds"
    width : int | None, optional
        Figure width in pixels. If None, uses 1000.
    height : int | None, optional
        Figure height in pixels. If None, auto-sizes based on interaction count
        (25px per interaction + 100px padding, with a minimum of 400px).
    show_values : bool, optional
        Whether to show interaction values on bars. Default is True.

    Returns
    -------
    go.Figure
        Interactive Plotly figure with:
        - Horizontal bars in input ranking order (first entry at the top)
        - Continuous color gradient indicating strength
        - Hover info showing exact values
        - Responsive layout for different screen sizes

    Raises
    ------
    ValueError
        If results dict is missing required keys or has invalid structure.
    TypeError
        If parameters have incorrect types.

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation import compute_shap_interactions
    >>> from ml4t.diagnostic.visualization import plot_interaction_bar
    >>>
    >>> # Compute SHAP interactions
    >>> results = compute_shap_interactions(model, X, top_k=20)
    >>>
    >>> # Plot top 10 interactions
    >>> fig = plot_interaction_bar(results, top_n=10)
    >>> fig.show()
    >>>
    >>> # Custom styling
    >>> fig = plot_interaction_bar(
    ...     results,
    ...     title="Strong Feature Interactions",
    ...     top_n=15,
    ...     theme="dark",
    ...     color_scheme="plasma",
    ...     height=700,
    ... )
    >>> fig.write_image("interactions.pdf")

    Notes
    -----
    - Works with both compute_shap_interactions() and analyze_interactions() results
    - Interaction strength is displayed as absolute magnitude (abs() is applied)
    - Pairs are deduplicated (A×B same as B×A); the first occurrence is kept
    - Bars are not re-sorted: the input order is assumed to be strongest-first
    - Use top_n to focus on strongest interactions
    """
    # Validate inputs
    theme = validate_theme(theme)
    if top_n is not None:
        validate_positive_int(top_n, "top_n")

    # Extract interaction entries - support both result formats
    if "top_interactions" in results:
        # From compute_shap_interactions() or single method
        raw_entries = results["top_interactions"]
    elif "consensus_ranking" in results:
        # From analyze_interactions(); entries carry extra trailing metadata
        raw_entries = results["consensus_ranking"]
    else:
        raise ValueError(
            "Results must contain 'top_interactions' (from compute_shap_interactions) "
            "or 'consensus_ranking' (from analyze_interactions)"
        )

    # Normalize to (feature_i, feature_j, value) and drop mirrored duplicates
    # (A×B and B×A describe the same interaction). The first occurrence wins so
    # the caller's ranking order is preserved.
    seen_pairs: set[frozenset[Any]] = set()
    interactions: list[tuple[Any, Any, Any]] = []
    for entry in raw_entries:
        feat_i, feat_j, value = entry[0], entry[1], entry[2]
        pair_key = frozenset((feat_i, feat_j))
        if pair_key in seen_pairs:
            continue
        seen_pairs.add(pair_key)
        interactions.append((feat_i, feat_j, value))

    # Limit to top N
    if top_n is not None:
        interactions = interactions[:top_n]

    # Reverse so the first entry of the ranking is drawn at the top of the chart
    ordered = interactions[::-1]
    pair_labels = [f"{feat_i} × {feat_j}" for feat_i, feat_j, _ in ordered]
    interaction_values = [abs(value) for _, _, value in ordered]

    # Get theme and colors
    theme_config = get_theme_config(theme)
    colors = get_colorscale(color_scheme or "viridis")

    # Create figure
    fig = go.Figure()

    fig.add_trace(
        go.Bar(
            x=interaction_values,
            y=pair_labels,
            orientation="h",
            marker={
                "color": interaction_values,
                "colorscale": colors,
                "showscale": True,
                "colorbar": {
                    "title": "Strength",
                    "tickformat": ".3f",
                },
            },
            text=[f"{v:.3f}" for v in interaction_values] if show_values else None,
            textposition="outside",
            hovertemplate="<b>%{y}</b><br>Interaction: %{x:.4f}<extra></extra>",
        )
    )

    # Apply the theme layout first, then the explicit settings in a separate
    # call. Splatting both into one call raises "got multiple values for keyword
    # argument" whenever the theme defines one of the same keys.
    fig.update_layout(**theme_config["layout"])
    fig.update_layout(
        title=title or "Feature Interactions - Top Pairs",
        xaxis_title="Interaction Strength",
        yaxis_title="Feature Pairs",
        width=width or 1000,
        height=height or max(400, len(pair_labels) * 25 + 100),
        showlegend=False,
    )

    # Apply responsive layout
    apply_responsive_layout(fig)

    return fig
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def plot_interaction_heatmap(
    results: dict[str, Any],
    *,
    title: str | None = None,
    theme: str | None = None,
    color_scheme: str | None = None,
    width: int | None = None,
    height: int | None = None,
    show_values: bool = False,  # False by default - can be crowded
) -> go.Figure:
    """Plot heatmap of feature interaction matrix.

    Creates a heatmap showing pairwise feature interactions. The matrix is
    expected to be symmetric (interaction(i,j) = interaction(j,i)). Diagonal
    cells are labelled as main effects in the hover text.

    Parameters
    ----------
    results : dict[str, Any]
        Results from compute_shap_interactions() or similar containing:
        - "interaction_matrix": array-like - (n_features, n_features) matrix
          (numpy array or nested lists)
        - "feature_names": list[str] - Feature names for axis labels
    title : str | None, optional
        Plot title. If None, uses "Feature Interaction Matrix"
    theme : str | None, optional
        Theme name ("default", "dark", "print", "presentation").
        If None, uses current global theme.
    color_scheme : str | None, optional
        Color scheme for heatmap. If None, uses "viridis".
        Recommended: "viridis", "plasma", "inferno", "magma", "cividis"
    width : int | None, optional
        Figure width in pixels. If None, uses 800.
    height : int | None, optional
        Figure height in pixels. If None, uses 800.
    show_values : bool, optional
        Whether to show interaction values in cells. Default is False
        (can be crowded for many features).

    Returns
    -------
    go.Figure
        Interactive Plotly heatmap with:
        - The interaction matrix, first feature in the top row
        - Continuous colorscale from weak to strong
        - Optional cell annotations (values rounded to 3 decimals)
        - Hover showing feature pairs and values

    Raises
    ------
    ValueError
        If results dict is missing required keys or has invalid structure,
        including an interaction_matrix that is not square with one row and
        column per entry of feature_names.
    TypeError
        If parameters have incorrect types.

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation import compute_shap_interactions
    >>> from ml4t.diagnostic.visualization import plot_interaction_heatmap
    >>>
    >>> # Compute interactions
    >>> results = compute_shap_interactions(model, X)
    >>>
    >>> # Create heatmap
    >>> fig = plot_interaction_heatmap(results)
    >>> fig.show()
    >>>
    >>> # With annotations for small feature sets
    >>> fig = plot_interaction_heatmap(
    ...     results,
    ...     show_values=True,  # Show numbers in cells
    ...     theme="print",
    ...     color_scheme="viridis",
    ... )

    Notes
    -----
    - The matrix is plotted as given; it is not symmetrized here
    - Diagonal elements are labelled as main effects (not interactions)
    - Off-diagonal elements are pairwise interactions
    - For many features (>20), consider hiding cell values (show_values=False)
    - Values are plotted as given; no abs() is applied (unlike plot_interaction_bar)
    """
    # Validate inputs
    validate_plot_results(
        results,
        required_keys=["interaction_matrix", "feature_names"],
        function_name="plot_interaction_heatmap",
    )
    theme = validate_theme(theme)

    # Extract data. np.asarray lets callers pass nested lists as well as
    # arrays; 2D indexing (matrix[i, j]) below requires an ndarray.
    feature_names = list(results["feature_names"])
    n_features = len(feature_names)
    matrix = np.asarray(results["interaction_matrix"])
    if n_features == 0 and matrix.size == 0:
        # An empty list converts to shape (0,); treat it as an empty 0x0 matrix
        matrix = matrix.reshape(0, 0)
    if matrix.shape != (n_features, n_features):
        raise ValueError(
            "plot_interaction_heatmap: 'interaction_matrix' must be a square matrix "
            f"of shape ({n_features}, {n_features}) matching 'feature_names', "
            f"got shape {matrix.shape}"
        )

    # Get theme and colors
    theme_config = get_theme_config(theme)
    colors = get_colorscale(color_scheme or "viridis")

    def _hover_label(row_idx: int, col_idx: int) -> str:
        """Build the hover string for one cell (diagonal = main effect)."""
        cell = matrix[row_idx, col_idx]
        if row_idx == col_idx:
            return f"<b>{feature_names[row_idx]}</b><br>Main Effect: {cell:.4f}"
        return (
            f"<b>{feature_names[row_idx]}</b> × <b>{feature_names[col_idx]}</b>"
            f"<br>Interaction: {cell:.4f}"
        )

    hover_text = [
        [_hover_label(row_idx, col_idx) for col_idx in range(n_features)]
        for row_idx in range(n_features)
    ]

    # Create figure
    fig = go.Figure()

    fig.add_trace(
        go.Heatmap(
            z=matrix,
            x=feature_names,
            y=feature_names,
            colorscale=colors,
            colorbar={
                "title": "Strength",
                "tickformat": ".3f",
            },
            text=np.round(matrix, 3) if show_values else None,
            texttemplate="%{text}" if show_values else None,
            textfont={"size": 10},
            hovertext=hover_text,
            hovertemplate="%{hovertext}<extra></extra>",
        )
    )

    # Apply the theme layout first, then the explicit settings in a separate
    # call. Splatting both into one call raises "got multiple values for keyword
    # argument" whenever the theme also defines e.g. xaxis/yaxis/width/height;
    # update_layout merges nested axis dicts, so theme axis styling is kept.
    fig.update_layout(**theme_config["layout"])
    fig.update_layout(
        title=title or "Feature Interaction Matrix",
        xaxis={
            "title": "",
            "side": "bottom",
            "tickangle": -45 if n_features > 10 else 0,
        },
        yaxis={
            "title": "",
            "autorange": "reversed",  # Top to bottom
        },
        width=width or 800,
        height=height or 800,
    )

    # Apply responsive layout
    apply_responsive_layout(fig)

    return fig
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def plot_interaction_network(
    results: dict[str, Any],
    *,
    title: str | None = None,
    threshold: float | None = None,
    top_n: int | None = None,
    theme: str | None = None,
    color_scheme: str | None = None,
    width: int | None = None,
    height: int | None = None,
    node_size: int = 30,
    show_edge_labels: bool = False,
) -> go.Figure:
    """Plot network graph of feature interactions.

    Creates an interactive network visualization where:
    - Nodes represent features
    - Edges represent interactions
    - Edge thickness indicates interaction strength
    - Only significant interactions above threshold are shown

    Parameters
    ----------
    results : dict[str, Any]
        Results from compute_shap_interactions() or analyze_interactions() containing:
        - "top_interactions": list[tuple[str, str, float]] - Feature pairs
        OR
        - "interaction_matrix" and "feature_names" - Will extract top interactions
          from the upper triangle of the (square) matrix.
    title : str | None, optional
        Plot title. If None, uses "Feature Interaction Network"
    threshold : float | None, optional
        Minimum absolute interaction strength to display. If None, uses an
        adaptive threshold: the larger of the median and the 80th percentile of
        absolute interaction strengths (so roughly the top 20% are kept).
    top_n : int | None, optional
        Maximum number of interactions to display, keeping the strongest by
        absolute value. If None, shows all above threshold.
        Useful to avoid cluttered networks.
    theme : str | None, optional
        Theme name ("default", "dark", "print", "presentation").
        If None, uses current global theme.
    color_scheme : str | None, optional
        Color scheme for nodes. If None, uses "set2".
        Recommended: "set2", "set3", "pastel", "bold"
    width : int | None, optional
        Figure width in pixels. If None, uses 1000.
    height : int | None, optional
        Figure height in pixels. If None, uses 800.
    node_size : int, optional
        Size of nodes in pixels. Default is 30.
    show_edge_labels : bool, optional
        Whether to show interaction values on edges. Default is False
        (can be cluttered).

    Returns
    -------
    go.Figure
        Interactive Plotly network graph with:
        - Nodes sorted by name and spaced evenly on a circle
        - Edge thickness proportional to absolute interaction strength
        - Optional edge labels showing absolute values
        - Hover info for nodes and edges
        - Pan/zoom capability

    Raises
    ------
    ValueError
        If results dict is missing required keys or contains no interactions.
        If no interactions remain after filtering.
    TypeError
        If parameters have incorrect types.

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation import compute_shap_interactions
    >>> from ml4t.diagnostic.visualization import plot_interaction_network
    >>>
    >>> # Compute interactions
    >>> results = compute_shap_interactions(model, X, top_k=30)
    >>>
    >>> # Create network showing only strong interactions
    >>> fig = plot_interaction_network(
    ...     results,
    ...     threshold=0.05,  # Show only interactions > 0.05
    ...     top_n=20  # Limit to top 20
    ... )
    >>> fig.show()
    >>>
    >>> # Show edge labels
    >>> fig = plot_interaction_network(
    ...     results,
    ...     show_edge_labels=True,
    ...     theme="dark"
    ... )

    Notes
    -----
    - Network layout is circular: nodes are sorted by name and placed evenly
      on the unit circle
    - Isolated nodes (no interactions) are excluded
    - Edge thickness is proportional to interaction strength
    - For complex networks (>50 edges), consider increasing threshold or using top_n
    - Use threshold and top_n together for best control
    """
    # Validate inputs
    theme = validate_theme(theme)
    if top_n is not None:
        validate_positive_int(top_n, "top_n")

    # Extract interactions as (feature_a, feature_b, strength) triples,
    # ordered strongest-first by absolute value so top_n keeps the strongest.
    if "top_interactions" in results:
        # sorted() returns a new list: never mutate the caller's results.
        # The sort is stable, so an already-sorted input keeps its order.
        interactions = sorted(results["top_interactions"], key=lambda x: abs(x[2]), reverse=True)
    elif "interaction_matrix" in results and "feature_names" in results:
        # asarray allows nested lists as well as ndarrays for [i, j] indexing.
        matrix = np.asarray(results["interaction_matrix"])
        feature_names = results["feature_names"]
        n_features = len(feature_names)

        interactions = [
            (feature_names[i], feature_names[j], matrix[i, j])
            for i in range(n_features)
            for j in range(i + 1, n_features)  # Upper triangle only
        ]
        interactions.sort(key=lambda x: abs(x[2]), reverse=True)
    else:
        raise ValueError(
            "Results must contain 'top_interactions' or 'interaction_matrix' + 'feature_names'"
        )

    if not interactions:
        # np.percentile on an empty sequence fails with an opaque error;
        # report the real problem instead.
        raise ValueError("Results contain no feature interactions to plot.")

    # Apply threshold
    if threshold is None:
        # Adaptive threshold: median or top 20%
        values = [abs(val) for _, _, val in interactions]
        median_threshold = np.median(values)
        percentile_threshold = np.percentile(values, 80)
        threshold = max(median_threshold, percentile_threshold)

    filtered_interactions = [(f1, f2, val) for f1, f2, val in interactions if abs(val) >= threshold]

    # Apply top_n limit (list is already strongest-first)
    if top_n is not None:
        filtered_interactions = filtered_interactions[:top_n]

    if len(filtered_interactions) == 0:
        raise ValueError(
            f"No interactions above threshold {threshold:.4f}. Try lowering threshold or increasing top_n."
        )

    # Build node set: only features that take part in a displayed edge
    node_set: set[str] = set()
    for f1, f2, _ in filtered_interactions:
        node_set.add(f1)
        node_set.add(f2)
    nodes = sorted(node_set)
    node_indices = {node: i for i, node in enumerate(nodes)}

    # Simple circular layout for nodes
    n_nodes = len(nodes)
    angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
    radius = 1.0

    node_x = radius * np.cos(angles)
    node_y = radius * np.sin(angles)

    # Get theme and colors
    theme_config = get_theme_config(theme)
    node_colors = get_color_scheme(color_scheme or "set2")

    # Create figure
    fig = go.Figure()

    # Add edges
    max_interaction = max(abs(val) for _, _, val in filtered_interactions)
    # All displayed strengths can be 0 when an explicit threshold <= 0 is given;
    # avoid dividing by zero in that case.
    width_scale = max_interaction if max_interaction > 0 else 1.0
    for f1, f2, val in filtered_interactions:
        i1 = node_indices[f1]
        i2 = node_indices[f2]

        # Edge thickness proportional to interaction strength (1 to 6 px)
        edge_width = 1 + 5 * (abs(val) / width_scale)

        fig.add_trace(
            go.Scatter(
                x=[node_x[i1], node_x[i2]],
                y=[node_y[i1], node_y[i2]],
                mode="lines",
                line={"width": edge_width, "color": "rgba(125,125,125,0.5)"},
                hoverinfo="text",
                hovertext=f"{f1} × {f2}<br>Interaction: {abs(val):.4f}",
                showlegend=False,
            )
        )

        # Optional edge labels at the midpoint of each edge
        if show_edge_labels:
            mid_x = (node_x[i1] + node_x[i2]) / 2
            mid_y = (node_y[i1] + node_y[i2]) / 2
            fig.add_annotation(
                x=mid_x,
                y=mid_y,
                text=f"{abs(val):.2f}",
                showarrow=False,
                font={"size": 8},
            )

    # Add nodes (drawn after edges so they sit on top)
    fig.add_trace(
        go.Scatter(
            x=node_x,
            y=node_y,
            mode="markers+text",
            marker={
                "size": node_size,
                "color": [node_colors[i % len(node_colors)] for i in range(n_nodes)],
                "line": {"width": 2, "color": "white"},
            },
            text=nodes,
            textposition="top center",
            textfont={"size": 10},
            hoverinfo="text",
            hovertext=nodes,
            showlegend=False,
        )
    )

    # Apply theme defaults first; the network-specific settings below are a
    # separate update so they override the theme without duplicate-keyword
    # conflicts (e.g. both defining xaxis/yaxis).
    fig.update_layout(**theme_config["layout"])
    fig.update_layout(
        title=title or "Feature Interaction Network",
        width=width or 1000,
        height=height or 800,
        xaxis={"showgrid": False, "zeroline": False, "showticklabels": False},
        yaxis={"showgrid": False, "zeroline": False, "showticklabels": False},
        hovermode="closest",
    )

    # Apply responsive layout
    apply_responsive_layout(fig)

    return fig
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Portfolio visualization module.
|
|
2
|
+
|
|
3
|
+
Plotly-based interactive visualizations for portfolio analysis.
|
|
4
|
+
Replacement for pyfolio's matplotlib-based plots.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .dashboard import create_portfolio_dashboard
|
|
8
|
+
from .drawdown_plots import (
|
|
9
|
+
plot_drawdown_periods,
|
|
10
|
+
plot_drawdown_underwater,
|
|
11
|
+
)
|
|
12
|
+
from .returns_plots import (
|
|
13
|
+
plot_annual_returns_bar,
|
|
14
|
+
plot_cumulative_returns,
|
|
15
|
+
plot_monthly_returns_heatmap,
|
|
16
|
+
plot_returns_distribution,
|
|
17
|
+
plot_rolling_returns,
|
|
18
|
+
)
|
|
19
|
+
from .risk_plots import (
|
|
20
|
+
plot_rolling_beta,
|
|
21
|
+
plot_rolling_sharpe,
|
|
22
|
+
plot_rolling_volatility,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Public API of the portfolio visualization package, grouped by plot family.
# Each name is re-exported from the submodule noted in its group comment.
__all__ = [
    # Returns (from .returns_plots)
    "plot_cumulative_returns",
    "plot_rolling_returns",
    "plot_annual_returns_bar",
    "plot_monthly_returns_heatmap",
    "plot_returns_distribution",
    # Risk (from .risk_plots)
    "plot_rolling_volatility",
    "plot_rolling_sharpe",
    "plot_rolling_beta",
    # Drawdown (from .drawdown_plots)
    "plot_drawdown_underwater",
    "plot_drawdown_periods",
    # Dashboard (from .dashboard)
    "create_portfolio_dashboard",
]
|