ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,801 @@
|
|
|
1
|
+
"""Feature importance dashboard for comprehensive analysis.
|
|
2
|
+
|
|
3
|
+
This module provides the FeatureImportanceDashboard class for creating
|
|
4
|
+
rich, interactive dashboards exploring feature importance across multiple methods.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any, Literal
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import plotly.graph_objects as go
|
|
13
|
+
from plotly.subplots import make_subplots
|
|
14
|
+
|
|
15
|
+
from ..data_extraction import (
|
|
16
|
+
ImportanceVizData,
|
|
17
|
+
extract_importance_viz_data,
|
|
18
|
+
extract_interaction_viz_data,
|
|
19
|
+
)
|
|
20
|
+
from .base import BaseDashboard, DashboardSection, get_theme
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class FeatureImportanceDashboard(BaseDashboard):
|
|
24
|
+
"""Interactive dashboard for comprehensive feature importance analysis.
|
|
25
|
+
|
|
26
|
+
Provides rich exploration of feature importance with 4-tab architecture:
|
|
27
|
+
- **Overview Tab**: Consensus ranking, method agreement heatmap, key insights
|
|
28
|
+
- **Method Comparison Tab**: Aligned feature rankings across methods, stability analysis
|
|
29
|
+
- **Feature Details Tab**: Searchable/filterable table with all features and metrics
|
|
30
|
+
- **Interactions Tab**: Network visualization or top pairs table (adaptive based on feature count)
|
|
31
|
+
|
|
32
|
+
All visualizations are interactive (zoom, pan, hover for details).
|
|
33
|
+
Real-time search filtering available in Feature Details tab.
|
|
34
|
+
|
|
35
|
+
Examples
|
|
36
|
+
--------
|
|
37
|
+
**Basic Usage: Single Method**
|
|
38
|
+
|
|
39
|
+
>>> from sklearn.ensemble import RandomForestClassifier
|
|
40
|
+
>>> from ml4t.diagnostic.evaluation import analyze_ml_importance
|
|
41
|
+
>>> from ml4t.diagnostic.visualization import FeatureImportanceDashboard
|
|
42
|
+
>>>
|
|
43
|
+
>>> # Train model
|
|
44
|
+
>>> model = RandomForestClassifier(n_estimators=100, random_state=42)
|
|
45
|
+
>>> model.fit(X_train, y_train)
|
|
46
|
+
>>>
|
|
47
|
+
>>> # Analyze importance (single method)
|
|
48
|
+
>>> results = analyze_ml_importance(
|
|
49
|
+
... model, X_train, y_train,
|
|
50
|
+
... methods=['mdi'],
|
|
51
|
+
... feature_names=X_train.columns
|
|
52
|
+
... )
|
|
53
|
+
>>>
|
|
54
|
+
>>> # Create and save dashboard
|
|
55
|
+
>>> dashboard = FeatureImportanceDashboard(title="MDI Analysis")
|
|
56
|
+
>>> dashboard.save("mdi_dashboard.html", results)
|
|
57
|
+
|
|
58
|
+
**Multiple Methods with Permutation Repeats**
|
|
59
|
+
|
|
60
|
+
>>> # Run multiple methods for comparison
|
|
61
|
+
>>> results = analyze_ml_importance(
|
|
62
|
+
... model, X_train, y_train,
|
|
63
|
+
... methods=['mdi', 'pfi'],
|
|
64
|
+
... pfi_n_repeats=10, # Get uncertainty estimates
|
|
65
|
+
... feature_names=X_train.columns
|
|
66
|
+
... )
|
|
67
|
+
>>>
|
|
68
|
+
>>> # Dark theme with custom top N
|
|
69
|
+
>>> dashboard = FeatureImportanceDashboard(
|
|
70
|
+
... title="Feature Importance: MDI vs PFI",
|
|
71
|
+
... theme="dark",
|
|
72
|
+
... n_top_features=15
|
|
73
|
+
... )
|
|
74
|
+
>>> dashboard.save("multi_method_dashboard.html", results)
|
|
75
|
+
|
|
76
|
+
**With Feature Interactions (SHAP)**
|
|
77
|
+
|
|
78
|
+
>>> from ml4t.diagnostic.evaluation import compute_shap_interactions
|
|
79
|
+
>>>
|
|
80
|
+
>>> # Compute importance
|
|
81
|
+
>>> importance_results = analyze_ml_importance(
|
|
82
|
+
... model, X_train, y_train,
|
|
83
|
+
... methods=['mdi', 'pfi', 'shap'],
|
|
84
|
+
... feature_names=X_train.columns
|
|
85
|
+
... )
|
|
86
|
+
>>>
|
|
87
|
+
>>> # Compute interactions (adds 3 more visualizations)
|
|
88
|
+
>>> interaction_results = compute_shap_interactions(
|
|
89
|
+
... model, X_train,
|
|
90
|
+
... feature_names=X_train.columns
|
|
91
|
+
... )
|
|
92
|
+
>>>
|
|
93
|
+
>>> # Create dashboard with both
|
|
94
|
+
>>> dashboard = FeatureImportanceDashboard(
|
|
95
|
+
... title="Full Feature Analysis with Interactions"
|
|
96
|
+
... )
|
|
97
|
+
>>> html = dashboard.generate(
|
|
98
|
+
... analysis_results=importance_results,
|
|
99
|
+
... interaction_results=interaction_results
|
|
100
|
+
... )
|
|
101
|
+
>>> with open("full_dashboard.html", "w") as f:
|
|
102
|
+
... f.write(html)
|
|
103
|
+
|
|
104
|
+
Notes
|
|
105
|
+
-----
|
|
106
|
+
- Dashboard requires results from `analyze_ml_importance()`
|
|
107
|
+
- Interaction visualizations only appear if `interaction_results` provided
|
|
108
|
+
- PFI distribution plots only shown if `pfi_n_repeats > 1`
|
|
109
|
+
- All Plotly visualizations are interactive (zoom, pan, hover)
|
|
110
|
+
|
|
111
|
+
See Also
|
|
112
|
+
--------
|
|
113
|
+
analyze_ml_importance : Compute feature importance across methods
|
|
114
|
+
compute_shap_interactions : Compute pairwise feature interactions
|
|
115
|
+
FeatureInteractionDashboard : Standalone interaction analysis
|
|
116
|
+
"""
|
|
117
|
+
|
|
118
|
+
# Class constants for visualization thresholds
|
|
119
|
+
INTERACTION_NETWORK_THRESHOLD = 20 # Show network if ≤20 features, else table
|
|
120
|
+
INTERACTION_TOP_EDGES = 20 # Number of top interaction pairs to display
|
|
121
|
+
INTERACTION_MATRIX_SIZE = 15 # Max features for interaction matrix heatmap
|
|
122
|
+
|
|
123
|
+
def __init__(
    self,
    title: str = "Feature Importance Analysis",
    theme: Literal["light", "dark"] = "light",
    width: int | None = None,
    height: int | None = None,
    n_top_features: int = 10,
):
    """Set up the dashboard and record how many top features to display.

    Parameters
    ----------
    title : str, default="Feature Importance Analysis"
        Dashboard title, forwarded to the base dashboard.
    theme : {'light', 'dark'}, default='light'
        Visual theme, forwarded to the base dashboard.
    width : int, optional
        Dashboard width in pixels, forwarded to the base dashboard.
    height : int, optional
        Dashboard height in pixels, forwarded to the base dashboard.
    n_top_features : int, default=10
        Upper bound on the number of consensus-ranked features that the
        charts of this dashboard slice out of the ranking.
    """
    # The four generic settings go to the base class positionally.
    super().__init__(title, theme, width, height)

    # Stored on the instance; the tab builders use it to slice the ranking.
    self.n_top_features = n_top_features
|
|
148
|
+
|
|
149
|
+
def generate(
    self,
    analysis_results: dict[str, Any],
    interaction_results: dict[str, Any] | None = None,
    **_kwargs,
) -> str:
    """Render the complete dashboard as an HTML document.

    Parameters
    ----------
    analysis_results : dict
        Results from ``analyze_ml_importance()``.
    interaction_results : dict, optional
        Results from ``compute_shap_interactions()``. When not ``None`` an
        additional "Interactions" tab is built.
    **_kwargs
        Accepted for signature compatibility; not used.

    Returns
    -------
    str
        The document returned by ``self._compose_html()``.
    """
    # Step 1: turn the raw analysis dict into the structured viz payload.
    importance_data = extract_importance_viz_data(analysis_results)

    # Step 2: build the tabbed section.
    # NOTE(review): the layout step appends one section to ``self.sections``
    # on every call; whether repeated ``generate()`` calls on the same
    # instance therefore duplicate content depends on the base class --
    # confirm.
    self._create_tabbed_layout(importance_data, interaction_results)

    # Step 3: let the base class wrap the sections into a full page.
    return self._compose_html()
|
|
179
|
+
|
|
180
|
+
def _create_tabbed_layout(
    self, viz_data: ImportanceVizData, interaction_results: dict[str, Any] | None = None
) -> None:
    """Build the tab bar and tab panes and register them as one section.

    Three tabs (overview, methods, features) are always produced; an
    "interactions" tab is added when ``interaction_results`` is not ``None``.
    The first tab is marked active. The resulting markup is wrapped in a
    single ``DashboardSection`` that is appended to ``self.sections``.

    Parameters
    ----------
    viz_data : ImportanceVizData
        Structured importance data handed to each tab builder.
    interaction_results : dict, optional
        Interaction results handed to the interactions tab builder.
    """
    with_interactions = interaction_results is not None

    # (element id, button label) for every tab, in display order.
    tabs = [
        ("overview", "Overview"),
        ("methods", "Method Comparison"),
        ("features", "Feature Details"),
    ]
    if with_interactions:
        tabs.append(("interactions", "Interactions"))

    # Render every pane up front, keyed by tab id.
    tab_contents = {}
    tab_contents["overview"] = self._create_overview_tab(viz_data)
    tab_contents["methods"] = self._create_method_comparison_tab(viz_data)
    tab_contents["features"] = self._create_feature_details_tab(viz_data)
    if with_interactions:
        tab_contents["interactions"] = self._create_interactions_tab(
            viz_data, interaction_results
        )

    # One pass over the tabs produces both the button and the pane markup.
    button_markup = []
    pane_markup = []
    for position, (tab_id, tab_name) in enumerate(tabs):
        active = " active" if position == 0 else ""
        button_markup.append(
            f'<button class="tab-button{active}" '
            f"onclick=\"switchTab(event, '{tab_id}')\">{tab_name}</button>"
        )
        pane_markup.append(
            f'<div id="{tab_id}" class="tab-content{active}">{tab_contents[tab_id]}</div>'
        )
    tab_buttons = "".join(button_markup)
    tab_divs = "".join(pane_markup)

    # Leading and trailing empty items reproduce the newline that opens and
    # closes the original template.
    html_content = "\n".join(
        [
            "",
            '<div class="tab-navigation">',
            tab_buttons,
            "</div>",
            tab_divs,
            "",
        ]
    )

    # The whole tabbed layout lives in a single section.
    self.sections.append(
        DashboardSection(
            title="Feature Importance Analysis", description="", content=html_content
        )
    )
|
|
238
|
+
|
|
239
|
+
def _create_overview_tab(self, viz_data: ImportanceVizData) -> str:
    """Build the HTML for the Overview tab.

    The tab contains, in order: four metric cards (feature count, method
    count with method names, top feature, average method agreement), a
    horizontal bar chart of the consensus-ranked top features, a
    "Key Insights" bullet list and, when more than one method is present,
    a method-agreement heatmap.

    Parameters
    ----------
    viz_data : ImportanceVizData
        Structured data; the keys ``summary``, ``per_feature``,
        ``per_method``, ``method_comparison`` and ``llm_context`` are read.

    Returns
    -------
    str
        The concatenated HTML fragments. Plotly figures are emitted with
        ``include_plotlyjs=False``, i.e. without the plotly.js bundle.
    """
    # Function-scope import: the name is only needed here.
    from html import escape

    summary = viz_data["summary"]
    per_feature = viz_data["per_feature"]
    per_method = viz_data["per_method"]
    method_comparison = viz_data["method_comparison"]
    llm_context = viz_data["llm_context"]

    methods = list(per_method.keys())
    method_names_display = ", ".join([m.upper() for m in methods])

    # Bucket the average agreement into a coarse label for the sub-caption.
    agreement = summary["avg_method_agreement"]
    if agreement > 0.7:
        agreement_level = "High"
    elif agreement > 0.4:
        agreement_level = "Medium"
    else:
        agreement_level = "Low"

    # The top feature name goes into raw HTML, so escape it: a name that
    # contains "<" or "&" would otherwise corrupt the card markup.
    top_feature_html = escape(str(summary["top_feature"]))

    fragments = ["<h2>Overview</h2>"]

    # Metric cards. The leading/trailing empty items reproduce the newline
    # that opens and closes the original template.
    card_lines = [
        "",
        '<div class="metric-grid">',
        '<div class="metric-card">',
        '<div class="metric-label">Features Analyzed</div>',
        f'<div class="metric-value">{summary["n_features"]}</div>',
        "</div>",
        '<div class="metric-card">',
        '<div class="metric-label">Methods Used</div>',
        f'<div class="metric-value">{summary["n_methods"]}</div>',
        f'<div class="metric-sublabel">({method_names_display})</div>',
        "</div>",
        '<div class="metric-card">',
        '<div class="metric-label">Top Feature</div>',
        f'<div class="metric-value">{top_feature_html}</div>',
        "</div>",
        '<div class="metric-card">',
        '<div class="metric-label" title="Average rank correlation between methods (1.0 = perfect agreement, 0.0 = no agreement)">',
        "Method Agreement",
        '<span style="font-size: 0.8em; color: #666; cursor: help;">ⓘ</span>',
        "</div>",
        f'<div class="metric-value">{agreement:.2f}</div>',
        '<div class="metric-sublabel">',
        f"{agreement_level} Agreement",
        "</div>",
        "</div>",
        "</div>",
        "",
    ]
    fragments.append("\n".join(card_lines))

    # Consensus ranking chart: scores are scaled by 100 for display as %.
    consensus_ranking = summary["consensus_ranking"][: self.n_top_features]
    consensus_scores = [
        per_feature[feat]["consensus_score"] * 100 for feat in consensus_ranking
    ]
    # Title reports the number of bars actually drawn, which is smaller than
    # n_top_features when the ranking holds fewer features.
    n_shown = len(consensus_ranking)

    fig_consensus = go.Figure()
    fig_consensus.add_trace(
        go.Bar(
            y=consensus_ranking,
            x=consensus_scores,
            orientation="h",
            marker={"color": consensus_scores, "colorscale": "Blues", "showscale": False},
            text=[f"{score:.2f}%" for score in consensus_scores],
            textposition="auto",
        )
    )
    fig_consensus.update_layout(
        title=f"Top {n_shown} Features (Consensus Ranking)",
        xaxis_title="Consensus Importance (%)",
        yaxis_title="Feature",
        # Reversed axis puts the first (highest-ranked) feature at the top.
        yaxis={"autorange": "reversed"},
        template=get_theme(self.theme)["template"],
        height=max(400, n_shown * 40),
        margin={"l": 150, "r": 50, "t": 80, "b": 80},
    )
    fragments.append(fig_consensus.to_html(include_plotlyjs=False, div_id="plot-consensus"))

    # Insights panel.
    # NOTE(review): insight strings are inserted verbatim; they are left
    # unescaped here because they may carry intentional markup -- confirm
    # against the code that produces ``llm_context["key_insights"]``.
    fragments.append(
        "\n".join(
            [
                "",
                '<div class="insights-panel">',
                "<h3>Key Insights</h3>",
                "<ul>",
                "",
            ]
        )
    )
    fragments.extend(f"<li>{insight}</li>" for insight in llm_context["key_insights"])
    fragments.append("</ul></div>")

    # Method agreement heatmap (only meaningful with more than one method).
    if len(methods) > 1:
        corr_matrix = method_comparison["correlation_matrix"]
        corr_labels = [m.upper() for m in method_comparison["correlation_methods"]]

        fig_corr = go.Figure(
            data=go.Heatmap(
                z=corr_matrix,
                x=corr_labels,
                y=corr_labels,
                colorscale="RdBu",
                zmid=0,
                zmin=-1,
                zmax=1,
                text=[[f"{val:.2f}" for val in row] for row in corr_matrix],
                texttemplate="%{text}",
                textfont={"size": 14},
                colorbar={"title": "Correlation"},
            )
        )
        fig_corr.update_layout(
            title="Method Agreement (Rank Correlation)",
            template=get_theme(self.theme)["template"],
            height=400,
            margin={"l": 100, "r": 100, "t": 100, "b": 80},
        )
        fragments.append(fig_corr.to_html(include_plotlyjs=False, div_id="plot-method-corr"))

    return "".join(fragments)
|
|
347
|
+
|
|
348
|
+
def _create_method_comparison_tab(self, viz_data: ImportanceVizData) -> str:
    """Build the HTML for the Method Comparison tab.

    The tab always starts with a heading and an explanatory panel. When
    more than one method is present it adds one horizontal bar subplot per
    method, all showing the same consensus-ranked top features in the same
    order (with error bars where the method supplies a ``std`` mapping).
    When coefficient-of-variation data exists for ``"pfi"``, a
    "Stability Analysis" heading and bar chart follow.

    Parameters
    ----------
    viz_data : ImportanceVizData
        Structured data; the keys ``summary``, ``per_method`` and
        ``uncertainty`` are read.

    Returns
    -------
    str
        The concatenated HTML fragments. Plotly figures are emitted with
        ``include_plotlyjs=False``, i.e. without the plotly.js bundle.
    """
    summary = viz_data["summary"]
    per_method = viz_data["per_method"]
    uncertainty = viz_data["uncertainty"]

    methods = list(per_method.keys())
    n_methods = len(methods)

    # The same consensus-ranked slice drives every chart in this tab.
    top_features = summary["consensus_ranking"][: self.n_top_features]
    n_shown = len(top_features)

    fragments = ["<h2>Method Comparison</h2>"]

    # Explanation panel. The leading/trailing empty items reproduce the
    # newline that opens and closes the original template.
    fragments.append(
        "\n".join(
            [
                "",
                '<div class="section-description">',
                "<p><strong>What you're seeing:</strong> The same top features (consensus-ranked) shown across all methods. This reveals where methods agree and disagree.</p>",
                "<p><strong>Look for:</strong></p>",
                "<ul>",
                "<li>Similar bar heights across methods → strong agreement on importance</li>",
                "<li>Very different heights → methods disagree (investigate further)</li>",
                "<li>Large error bars → unstable importance estimates</li>",
                "</ul>",
                "</div>",
                "",
            ]
        )
    )

    if n_methods > 1:
        # Side-by-side comparison: one column per method, shared y axis.
        fig_methods = make_subplots(
            rows=1,
            cols=n_methods,
            subplot_titles=[m.upper() for m in methods],
            shared_yaxes=True,
        )

        for col_idx, method_name in enumerate(methods, start=1):
            method_data = per_method[method_name]

            # Importance per consensus feature, scaled by 100 for display
            # as %; features missing for this method count as 0.
            importances = [
                method_data["importances"].get(feat, 0) * 100 for feat in top_features
            ]

            # Bars are horizontal, so the spread belongs on the x axis.
            error_x = None
            if method_data["std"] is not None:
                error_x = {
                    "type": "data",
                    "array": [method_data["std"].get(feat, 0) * 100 for feat in top_features],
                    "visible": True,
                }

            fig_methods.add_trace(
                go.Bar(
                    y=top_features,
                    x=importances,
                    orientation="h",
                    error_x=error_x,
                    text=[f"{imp:.2f}%" for imp in importances],
                    textposition="auto",
                    showlegend=False,
                ),
                row=1,
                col=col_idx,
            )
            fig_methods.update_xaxes(title_text="Importance (%)", row=1, col=col_idx)

        fig_methods.update_yaxes(title_text="Feature", autorange="reversed", row=1, col=1)
        fig_methods.update_layout(
            # Title and height follow the number of features actually
            # drawn, which is below n_top_features for small feature sets.
            title=f"Top {n_shown} Features: How Each Method Ranks Them",
            template=get_theme(self.theme)["template"],
            height=max(500, n_shown * 40),
            margin={"l": 150, "r": 50, "t": 100, "b": 80},
        )
        fragments.append(fig_methods.to_html(include_plotlyjs=False, div_id="plot-methods"))

    # Stability analysis. Only the "pfi" entry is plotted, so the heading is
    # gated on that entry as well; otherwise CV data for other methods alone
    # would leave a heading with nothing beneath it.
    cv_by_method = uncertainty.get("coefficient_of_variation") or {}
    cv_data = cv_by_method.get("pfi", {})
    if cv_data:
        fragments.append('<h3 style="margin-top: 30px;">Stability Analysis</h3>')

        # Features without a CV value are drawn as 0.
        cv_values = [cv_data.get(feat, 0) for feat in top_features]

        fig_cv = go.Figure()
        fig_cv.add_trace(
            go.Bar(
                y=top_features,
                x=cv_values,
                orientation="h",
                marker={
                    "color": cv_values,
                    "colorscale": "Reds",
                    "showscale": False,
                    # NOTE(review): reversing the scale presumably gives the
                    # lowest CV the darkest shade -- confirm this is the
                    # intended emphasis given "Lower is Better".
                    "reversescale": True,
                },
                text=[f"{cv:.2f}" for cv in cv_values],
                textposition="auto",
            )
        )
        fig_cv.update_layout(
            title="Feature Stability (Coefficient of Variation - Lower is Better)",
            xaxis_title="Coefficient of Variation",
            yaxis_title="Feature",
            yaxis={"autorange": "reversed"},
            template=get_theme(self.theme)["template"],
            height=max(400, n_shown * 40),
            margin={"l": 150, "r": 50, "t": 80, "b": 80},
        )
        fragments.append(fig_cv.to_html(include_plotlyjs=False, div_id="plot-cv"))

    return "".join(fragments)
|
|
463
|
+
|
|
464
|
+
def _create_feature_details_tab(self, viz_data: ImportanceVizData) -> str:
|
|
465
|
+
"""Create Feature Details tab with searchable, filterable table."""
|
|
466
|
+
summary = viz_data["summary"]
|
|
467
|
+
per_feature = viz_data["per_feature"]
|
|
468
|
+
per_method = viz_data["per_method"]
|
|
469
|
+
|
|
470
|
+
methods = list(per_method.keys())
|
|
471
|
+
|
|
472
|
+
html = ["<h2>Feature Details</h2>"]
|
|
473
|
+
|
|
474
|
+
# Search box
|
|
475
|
+
html.append("""
|
|
476
|
+
<div style="margin: 20px 0;">
|
|
477
|
+
<input type="text" id="feature-search" placeholder="Type to filter features..."
|
|
478
|
+
style="width: 100%; padding: 10px; font-size: 16px; border: 1px solid #ccc; border-radius: 4px;">
|
|
479
|
+
</div>
|
|
480
|
+
""")
|
|
481
|
+
|
|
482
|
+
# Build detailed feature table
|
|
483
|
+
table_rows = []
|
|
484
|
+
for rank, feature_name in enumerate(summary["consensus_ranking"], start=1):
|
|
485
|
+
feat_data = per_feature[feature_name]
|
|
486
|
+
|
|
487
|
+
# Determine if this row should be highlighted for low agreement
|
|
488
|
+
agreement_class = " low-agreement" if feat_data["agreement_level"] == "low" else ""
|
|
489
|
+
|
|
490
|
+
# Build row
|
|
491
|
+
row_cells = [
|
|
492
|
+
f'<td style="font-weight: 600;">{rank}</td>',
|
|
493
|
+
f'<td style="font-weight: 500;">{feature_name}</td>',
|
|
494
|
+
f"<td>{feat_data['consensus_score'] * 100:.2f}%</td>",
|
|
495
|
+
]
|
|
496
|
+
|
|
497
|
+
# Add separate column for each method's rank
|
|
498
|
+
for m in methods:
|
|
499
|
+
method_rank = feat_data["method_ranks"].get(m, None)
|
|
500
|
+
if method_rank is not None:
|
|
501
|
+
row_cells.append(f'<td style="text-align: center;">{method_rank}</td>')
|
|
502
|
+
else:
|
|
503
|
+
row_cells.append('<td style="text-align: center; opacity: 0.3;">-</td>')
|
|
504
|
+
|
|
505
|
+
# Add agreement and stability
|
|
506
|
+
row_cells.extend(
|
|
507
|
+
[
|
|
508
|
+
f'<td><span class="badge badge-{feat_data["agreement_level"]}">{feat_data["agreement_level"]}</span></td>',
|
|
509
|
+
f"<td>{feat_data['stability_score']:.3f}</td>",
|
|
510
|
+
]
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
table_rows.append(f'<tr class="feature-row{agreement_class}">{"".join(row_cells)}</tr>')
|
|
514
|
+
|
|
515
|
+
# Build table header
|
|
516
|
+
method_header_cols = "".join(
|
|
517
|
+
[
|
|
518
|
+
f'<th title="{m.upper()} rank (lower is better)" style="text-align: center;">{m.upper()}<br/>Rank</th>'
|
|
519
|
+
for m in methods
|
|
520
|
+
]
|
|
521
|
+
)
|
|
522
|
+
|
|
523
|
+
html.append(f"""
|
|
524
|
+
<div style="overflow-x: auto;">
|
|
525
|
+
<table class="feature-table" id="feature-importance-table">
|
|
526
|
+
<thead>
|
|
527
|
+
<tr>
|
|
528
|
+
<th title="Consensus rank across all methods">Consensus<br/>Rank</th>
|
|
529
|
+
<th>Feature</th>
|
|
530
|
+
<th title="Average importance across all methods (normalized to %)">Consensus<br/>Score (%)</th>
|
|
531
|
+
{method_header_cols}
|
|
532
|
+
<th title="How well methods agree on this feature's importance (high/medium/low)">Agreement</th>
|
|
533
|
+
<th title="Consistency of rank across resampling (1.0 = perfectly stable)">Stability</th>
|
|
534
|
+
</tr>
|
|
535
|
+
</thead>
|
|
536
|
+
<tbody>
|
|
537
|
+
{"".join(table_rows)}
|
|
538
|
+
</tbody>
|
|
539
|
+
</table>
|
|
540
|
+
</div>
|
|
541
|
+
<p style="font-size: 0.85em; opacity: 0.7; margin-top: 10px;">
|
|
542
|
+
💡 <strong>Tip:</strong> Click column headers to sort the table. Use the search box above to filter features.
|
|
543
|
+
Rows highlighted in orange indicate low method agreement.
|
|
544
|
+
</p>
|
|
545
|
+
|
|
546
|
+
<script>
|
|
547
|
+
// Real-time search filtering
|
|
548
|
+
document.getElementById('feature-search').addEventListener('input', function(e) {{
|
|
549
|
+
const searchTerm = e.target.value.toLowerCase();
|
|
550
|
+
const table = document.getElementById('feature-importance-table');
|
|
551
|
+
|
|
552
|
+
// Guard against missing table
|
|
553
|
+
if (!table) return;
|
|
554
|
+
|
|
555
|
+
const tbody = table.getElementsByTagName('tbody')[0];
|
|
556
|
+
if (!tbody) return;
|
|
557
|
+
|
|
558
|
+
const rows = tbody.getElementsByTagName('tr');
|
|
559
|
+
|
|
560
|
+
for (let row of rows) {{
|
|
561
|
+
// Guard against malformed rows
|
|
562
|
+
if (!row.cells || row.cells.length < 2) continue;
|
|
563
|
+
|
|
564
|
+
const featureName = row.cells[1].textContent.toLowerCase();
|
|
565
|
+
row.style.display = featureName.includes(searchTerm) ? '' : 'none';
|
|
566
|
+
}}
|
|
567
|
+
}});
|
|
568
|
+
</script>
|
|
569
|
+
""")
|
|
570
|
+
|
|
571
|
+
return "".join(html)
|
|
572
|
+
|
|
573
|
+
def _create_interactions_tab(
    self, viz_data: ImportanceVizData, interaction_results: dict[str, Any]
) -> str:
    """Create Interactions tab with adaptive display based on feature count.

    Always renders an insight panel naming the strongest interacting pair and
    the most interactive feature. What follows depends on
    ``summary["n_features"]``:

    * above ``self.INTERACTION_NETWORK_THRESHOLD``: a table and a horizontal
      bar chart of the first ten edges;
    * otherwise: a circular network plot (only when there are edges) and an
      interaction-matrix heatmap limited to ``self.INTERACTION_MATRIX_SIZE``
      features.

    Feature names written into raw HTML (insight panel, table cells) are
    HTML-escaped. Names passed to Plotly traces are left untouched.

    Args:
        viz_data: Mapping whose ``"summary"`` entry provides ``"n_features"``
            and ``"consensus_ranking"``.
        interaction_results: Raw interaction results, forwarded to
            ``extract_interaction_viz_data``.

    Returns:
        The tab's HTML fragment as a single string.
    """
    # Function-scope import: binds only ``escape`` for the raw-HTML fragments.
    from html import escape

    summary = viz_data["summary"]
    n_features = summary["n_features"]

    parts = ["<h2>Feature Interactions</h2>"]
    parts.append("""
    <p style="margin: 20px 0; font-style: italic; opacity: 0.8;">
        SHAP interaction values show how feature contributions change based on other features.
        Strong interactions suggest non-linear relationships and feature dependencies.
    </p>
    """)

    # Extract interaction viz data
    interaction_viz_data = extract_interaction_viz_data(
        interaction_results,
        importance_results={"consensus_ranking": summary["consensus_ranking"]},
    )

    inter_summary = interaction_viz_data["summary"]
    network_data = interaction_viz_data["network_graph"]

    # Top interaction info
    strongest_pair = inter_summary["strongest_pair"]
    strongest_value = inter_summary["strongest_interaction"]
    pair_a = escape(str(strongest_pair[0]))
    pair_b = escape(str(strongest_pair[1]))
    most_interactive = escape(str(inter_summary["most_interactive_feature"]))
    max_total = inter_summary["max_total_interaction"]

    parts.append(f"""
    <div class="insights-panel">
        <h3>Top Interaction</h3>
        <p><strong>{pair_a}</strong> ↔ <strong>{pair_b}</strong>: {strongest_value:.4f}</p>
        <p style="margin-top: 10px;">Most interactive feature: <strong>{most_interactive}</strong>
        (total strength: {max_total:.4f})</p>
    </div>
    """)

    # Adaptive display based on feature count
    edges = network_data["edges"][: self.INTERACTION_TOP_EDGES]
    theme_template = get_theme(self.theme)["template"]

    if n_features > self.INTERACTION_NETWORK_THRESHOLD:
        # Too many features - show top pairs table instead of network
        top_edges = edges[:10]

        parts.append(f"""
        <div class="section-description">
            <p><strong>Note:</strong> With {n_features} features, a network visualization would be unreadable.
            Showing top 10 interaction pairs instead.</p>
        </div>
        """)

        # Top interaction pairs table
        parts.append("""
        <h3 style="margin-top: 30px;">Top 10 Interaction Pairs</h3>
        <div style="overflow-x: auto;">
            <table class="feature-table">
                <thead>
                    <tr>
                        <th>Rank</th>
                        <th>Feature A</th>
                        <th>Feature B</th>
                        <th>Interaction Strength</th>
                    </tr>
                </thead>
                <tbody>
        """)

        for i, edge in enumerate(top_edges, start=1):
            source = escape(str(edge["source"]))
            target = escape(str(edge["target"]))
            parts.append(f"""
                    <tr>
                        <td>{i}</td>
                        <td><strong>{source}</strong></td>
                        <td><strong>{target}</strong></td>
                        <td>{edge["abs_weight"]:.4f}</td>
                    </tr>
            """)

        parts.append("</tbody></table></div>")

        # Bar chart of top pairs (labels go to Plotly, so they are not escaped)
        pair_labels = [f"{edge['source']} ↔ {edge['target']}" for edge in top_edges]
        pair_strengths = [edge["abs_weight"] for edge in top_edges]

        fig_pairs = go.Figure()
        fig_pairs.add_trace(
            go.Bar(
                y=pair_labels,
                x=pair_strengths,
                orientation="h",
                marker={"color": pair_strengths, "colorscale": "Viridis", "showscale": False},
                text=[f"{s:.4f}" for s in pair_strengths],
                textposition="auto",
            )
        )
        fig_pairs.update_layout(
            title="Top 10 Feature Interaction Pairs",
            xaxis_title="Interaction Strength",
            yaxis_title="Feature Pair",
            yaxis={"autorange": "reversed"},
            template=theme_template,
            height=500,
        )
        parts.append(fig_pairs.to_html(include_plotlyjs=False, div_id="plot-top-pairs"))

    else:
        # Few features - show network + matrix
        nodes = network_data["nodes"]

        if edges:
            # Create network visualization: nodes evenly spaced on a unit circle
            n_nodes = len(nodes)
            angles = [2 * np.pi * i / n_nodes for i in range(n_nodes)]
            node_x = [np.cos(angle) for angle in angles]
            node_y = [np.sin(angle) for angle in angles]

            node_positions = {
                node["id"]: (node_x[i], node_y[i]) for i, node in enumerate(nodes)
            }
            node_importances = {node["id"]: node["importance"] for node in nodes}

            # Create edge traces (one line segment per edge, width ~ strength)
            edge_traces = []
            for edge in edges:
                x0, y0 = node_positions[edge["source"]]
                x1, y1 = node_positions[edge["target"]]

                edge_traces.append(
                    go.Scatter(
                        x=[x0, x1, None],
                        y=[y0, y1, None],
                        mode="lines",
                        line={
                            "width": edge["abs_weight"] * 10,
                            "color": "rgba(125,125,125,0.3)",
                        },
                        hoverinfo="none",
                        showlegend=False,
                    )
                )

            # Create node trace (size ~ importance, colour ~ total interaction)
            node_trace = go.Scatter(
                x=node_x,
                y=node_y,
                mode="markers+text",
                marker={
                    "size": [node_importances.get(node["id"], 0.1) * 100 for node in nodes],
                    "color": [node["total_interaction"] for node in nodes],
                    "colorscale": "Viridis",
                    "showscale": True,
                    "colorbar": {"title": "Total<br>Interaction"},
                    "line": {"width": 2, "color": "white"},
                },
                text=[node["label"] for node in nodes],
                textposition="top center",
                textfont={"size": 10},
                hovertemplate="<b>%{text}</b><br>Total Interaction: %{marker.color:.3f}<extra></extra>",
                showlegend=False,
            )

            fig_network = go.Figure(data=[*edge_traces, node_trace])
            fig_network.update_layout(
                title=f"Feature Interaction Network (Top {len(edges)} Strongest Interactions)",
                template=theme_template,
                height=600,
                xaxis={"showgrid": False, "zeroline": False, "showticklabels": False},
                yaxis={"showgrid": False, "zeroline": False, "showticklabels": False},
                hovermode="closest",
            )
            parts.append(fig_network.to_html(include_plotlyjs=False, div_id="plot-network"))

        # Interaction matrix heatmap
        matrix_data = interaction_viz_data["interaction_matrix"]
        matrix = matrix_data["matrix"]

        # Show up to INTERACTION_MATRIX_SIZE features for readability
        n_features_display = min(self.INTERACTION_MATRIX_SIZE, len(matrix_data["features"]))
        features_for_matrix = matrix_data["features"][:n_features_display]

        # Extract the top-left square sub-matrix matching the displayed features
        matrix_subset = [
            [matrix[i][j] for j in range(n_features_display)] for i in range(n_features_display)
        ]

        fig_matrix = go.Figure(
            data=go.Heatmap(
                z=matrix_subset,
                x=features_for_matrix,
                y=features_for_matrix,
                colorscale="RdBu",
                zmid=0,
                colorbar={"title": "Interaction<br>Strength"},
            )
        )

        fig_matrix.update_layout(
            title=f"Interaction Matrix Heatmap (Top {len(features_for_matrix)} Features)",
            template=theme_template,
            height=600,
            xaxis={"tickangle": -45},
        )
        parts.append(fig_matrix.to_html(include_plotlyjs=False, div_id="plot-matrix"))

    return "".join(parts)
|
|
776
|
+
|
|
777
|
+
def _compose_html(self) -> str:
|
|
778
|
+
"""Compose final HTML document."""
|
|
779
|
+
return f"""
|
|
780
|
+
<!DOCTYPE html>
|
|
781
|
+
<html>
|
|
782
|
+
<head>
|
|
783
|
+
<meta charset="utf-8">
|
|
784
|
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
|
785
|
+
<title>{self.title}</title>
|
|
786
|
+
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
|
787
|
+
{self._get_base_styles()}
|
|
788
|
+
</head>
|
|
789
|
+
<body>
|
|
790
|
+
{self._build_header()}
|
|
791
|
+
{self._build_navigation()}
|
|
792
|
+
<div class="dashboard-container">
|
|
793
|
+
{self._build_sections()}
|
|
794
|
+
</div>
|
|
795
|
+
{self._get_base_scripts()}
|
|
796
|
+
</body>
|
|
797
|
+
</html>
|
|
798
|
+
"""
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
# Names exported by a star-import of this module.
__all__ = ["FeatureImportanceDashboard"]
|