ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,888 @@
|
|
|
1
|
+
"""Feature importance visualization functions.
|
|
2
|
+
|
|
3
|
+
This module provides functions for visualizing ML feature importance analysis results
|
|
4
|
+
from analyze_ml_importance() and related functions.
|
|
5
|
+
|
|
6
|
+
All plot functions follow the standard API defined in docs/plot_api_standards.md:
|
|
7
|
+
- Consume results dicts from analyze_*() functions
|
|
8
|
+
- Return plotly.graph_objects.Figure instances
|
|
9
|
+
- Support theme customization via global or per-plot settings
|
|
10
|
+
- Use keyword-only arguments (after results)
|
|
11
|
+
- Provide comprehensive hover information and interactivity
|
|
12
|
+
|
|
13
|
+
Example workflow:
|
|
14
|
+
>>> from ml4t.diagnostic.evaluation import analyze_ml_importance
|
|
15
|
+
>>> from ml4t.diagnostic.visualization import plot_importance_bar, set_plot_theme
|
|
16
|
+
>>>
|
|
17
|
+
>>> # Analyze feature importance
|
|
18
|
+
>>> results = analyze_ml_importance(model, X, y, methods=["mdi", "pfi"])
|
|
19
|
+
>>>
|
|
20
|
+
>>> # Set global theme
|
|
21
|
+
>>> set_plot_theme("dark")
|
|
22
|
+
>>>
|
|
23
|
+
>>> # Create visualizations
|
|
24
|
+
>>> fig_bar = plot_importance_bar(results, top_n=15)
|
|
25
|
+
>>> fig_heatmap = plot_importance_heatmap(results)
|
|
26
|
+
>>> fig_dist = plot_importance_distribution(results)
|
|
27
|
+
>>> fig_summary = plot_importance_summary(results)
|
|
28
|
+
>>>
|
|
29
|
+
>>> # Display or save
|
|
30
|
+
>>> fig_bar.show()
|
|
31
|
+
>>> fig_summary.write_html("importance_report.html")
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from typing import Any
|
|
35
|
+
|
|
36
|
+
import numpy as np
|
|
37
|
+
import plotly.graph_objects as go
|
|
38
|
+
from plotly.subplots import make_subplots
|
|
39
|
+
|
|
40
|
+
from ml4t.diagnostic.visualization.core import (
|
|
41
|
+
apply_responsive_layout,
|
|
42
|
+
format_number,
|
|
43
|
+
get_color_scheme,
|
|
44
|
+
get_colorscale,
|
|
45
|
+
get_theme_config,
|
|
46
|
+
validate_plot_results,
|
|
47
|
+
validate_positive_int,
|
|
48
|
+
validate_theme,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
__all__ = [
|
|
52
|
+
"plot_importance_bar",
|
|
53
|
+
"plot_importance_heatmap",
|
|
54
|
+
"plot_importance_distribution",
|
|
55
|
+
"plot_importance_summary",
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def plot_importance_bar(
    results: dict[str, Any],
    *,
    title: str | None = None,
    top_n: int | None = 20,
    theme: str | None = None,
    color_scheme: str | None = None,
    width: int | None = None,
    height: int | None = None,
    show_values: bool = True,
) -> go.Figure:
    """Plot a horizontal bar chart of consensus feature importance.

    Features are taken, in order, from ``results["consensus_ranking"]``. Each
    bar's length (and colour) is the mean of that feature's importance over
    every method in ``results["method_results"]`` that lists the feature.

    Parameters
    ----------
    results : dict[str, Any]
        Results dict (per the module docs, from analyze_ml_importance()) with:
        - "consensus_ranking": sequence of feature names, most important first
        - "method_results": mapping of method name to a dict holding
          "feature_names" and "importances" (or "importances_mean" for the
          method named "pfi")
    title : str | None, optional
        Plot title. If None, uses "Feature Importance - Consensus Ranking".
    top_n : int | None, optional
        Number of leading features to display. If None, all features are
        shown. Default is 20.
    theme : str | None, optional
        Theme name, passed through validate_theme() and get_theme_config().
    color_scheme : str | None, optional
        Colorscale name handed to get_colorscale(). If None, uses "viridis".
    width : int | None, optional
        Figure width in pixels. If None (or 0), 1000 is used.
    height : int | None, optional
        Figure height in pixels. If None (or 0), the height is
        ``max(400, 25 * n_features + 100)``.
    show_values : bool, optional
        Whether to print the formatted importance next to each bar.
        Default is True.

    Returns
    -------
    go.Figure
        Figure with a single horizontal bar trace, a colour bar titled
        "Importance" and a hover label showing the value to 4 decimals.

    Raises
    ------
    ValueError, TypeError
        Input checks are delegated to validate_plot_results(),
        validate_theme() and validate_positive_int(); presumably these raise
        ValueError/TypeError on bad input (confirm in visualization.core).

    Examples
    --------
    >>> results = analyze_ml_importance(model, X, y)
    >>> fig = plot_importance_bar(results, top_n=10)
    >>> fig.show()

    Notes
    -----
    - A feature that no method reports gets a score of 0.0.
    - If a method lists a feature name more than once, the first occurrence
      is the one used.
    """
    # Validate inputs
    validate_plot_results(
        results,
        required_keys=["consensus_ranking", "method_results"],
        function_name="plot_importance_bar",
    )
    theme = validate_theme(theme)
    if top_n is not None:
        validate_positive_int(top_n, "top_n")

    # color_scheme is not checked here; it is handed to get_colorscale() below.

    # Select the features to display
    all_features = results["consensus_ranking"]
    features = all_features[:top_n] if top_n is not None else all_features

    # Resolve each method's importance vector and a name -> position map once,
    # instead of re-scanning the feature-name list for every displayed feature.
    method_results = results["method_results"]
    per_method = []
    # len() rather than truthiness so array-like rankings are also accepted;
    # with nothing to plot the method results are never inspected.
    if len(features) > 0:
        for method_name, method_result in method_results.items():
            # PFI stores its scores under "importances_mean"; the other
            # methods use "importances".
            importances_key = "importances_mean" if method_name == "pfi" else "importances"
            importances = method_result[importances_key]

            # setdefault keeps the first position of a repeated name, which
            # matches what list.index() would have returned.
            first_position = {}
            for position, name in enumerate(method_result["feature_names"]):
                first_position.setdefault(name, position)

            per_method.append((importances, first_position))

    # Average importance across the methods that know each feature
    importance_scores = []
    for feat in features:
        scores = [
            importances[first_position[feat]]
            for importances, first_position in per_method
            if feat in first_position
        ]
        importance_scores.append(float(np.mean(scores)) if scores else 0.0)

    # Get theme configuration
    theme_config = get_theme_config(theme)

    # Get colors
    colors = get_colorscale(color_scheme or "viridis")

    # Create figure
    fig = go.Figure()

    # Add bar trace
    fig.add_trace(
        go.Bar(
            x=importance_scores,
            y=features,
            orientation="h",
            marker={
                "color": importance_scores,
                "colorscale": colors,
                "showscale": True,
                "colorbar": {
                    "title": "Importance",
                    "tickformat": ".3f",
                },
            },
            text=[format_number(v, precision=3) for v in importance_scores]
            if show_values
            else None,
            textposition="outside",
            hovertemplate="<b>%{y}</b><br>Importance: %{x:.4f}<extra></extra>",
        )
    )

    # Update layout
    fig.update_layout(
        title=title or "Feature Importance - Consensus Ranking",
        xaxis_title="Consensus Importance Score",
        yaxis_title="Features",
        **theme_config["layout"],
        width=width or 1000,
        height=height or max(400, len(features) * 25 + 100),
        showlegend=False,
    )

    # Apply responsive layout
    apply_responsive_layout(fig)

    return fig
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def plot_importance_heatmap(
    results: dict[str, Any],
    *,
    title: str | None = None,
    theme: str | None = None,
    color_scheme: str | None = None,
    width: int | None = None,
    height: int | None = None,
    show_values: bool = True,
) -> go.Figure:
    """Plot a heatmap of pairwise agreement between importance methods.

    Builds a square, symmetric matrix with one row/column per entry of
    ``results["methods_run"]``. The diagonal is 1.0 and each off-diagonal
    cell is read from ``results["method_agreement"]`` under the key
    ``"<a>_vs_<b>"`` (either order is accepted).

    Parameters
    ----------
    results : dict[str, Any]
        Results dict (per the module docs, from analyze_ml_importance()) with:
        - "method_agreement": mapping of "<a>_vs_<b>" to a correlation value
        - "methods_run": names of the methods to show
    title : str | None, optional
        Plot title. If None, uses "Method Agreement - Ranking Correlations".
    theme : str | None, optional
        Theme name, passed through validate_theme() and get_theme_config().
    color_scheme : str | None, optional
        Colorscale name handed to get_colorscale(). If None, uses "rdbu".
    width : int | None, optional
        Figure width in pixels. If None (or 0), uses 800.
    height : int | None, optional
        Figure height in pixels. If None (or 0), uses 800.
    show_values : bool, optional
        Whether to print the value (rounded to 3 decimals) in each cell.
        Default is True.

    Returns
    -------
    go.Figure
        Figure with one heatmap trace whose colour range is fixed to
        [-1, 1], upper-cased method labels on both axes and a reversed
        y axis.

    Raises
    ------
    ValueError
        If fewer than 2 methods are listed in ``results["methods_run"]``.
        Further checks are delegated to validate_plot_results() and
        validate_theme() (presumably ValueError/TypeError; confirm in
        visualization.core).

    Examples
    --------
    >>> results = analyze_ml_importance(model, X, y, methods=["mdi", "pfi", "shap"])
    >>> fig = plot_importance_heatmap(results)
    >>> fig.show()

    Notes
    -----
    - A method pair with no entry in "method_agreement" is shown as 0.0.
    """
    # Input checks
    validate_plot_results(
        results,
        required_keys=["method_agreement", "methods_run"],
        function_name="plot_importance_heatmap",
    )
    theme = validate_theme(theme)

    # color_scheme is not checked here; it is handed to get_colorscale() below.

    methods = results["methods_run"]
    if len(methods) < 2:
        raise ValueError(f"plot_importance_heatmap requires at least 2 methods, got {len(methods)}")

    # Start from the identity (self-agreement = 1.0) and fill both triangles
    # from the pairwise lookups.
    pairwise = results["method_agreement"]
    correlation_matrix = np.eye(len(methods))

    for row_idx, row_method in enumerate(methods):
        for col_idx, col_method in enumerate(methods):
            if col_idx <= row_idx:
                # Only the strict upper triangle is looked up; the mirror
                # cell is written together with it.
                continue

            # The pair may be stored under either key order; when neither
            # exists the cell falls back to 0.0.
            pair_value = 0.0
            for candidate in (
                f"{row_method}_vs_{col_method}",
                f"{col_method}_vs_{row_method}",
            ):
                if candidate in pairwise:
                    pair_value = pairwise[candidate]
                    break

            correlation_matrix[row_idx, col_idx] = pair_value
            correlation_matrix[col_idx, row_idx] = pair_value

    # Theme settings and the (diverging) colorscale
    theme_config = get_theme_config(theme)
    colors = get_colorscale(color_scheme or "rdbu")

    fig = go.Figure()

    # Axis labels and per-cell hover text
    labels = [name.upper() for name in methods]
    hover_text = [
        [
            f"<b>{row_label}</b> vs <b>{col_label}</b><br>Correlation: {correlation_matrix[r, c]:.3f}"
            for c, col_label in enumerate(labels)
        ]
        for r, row_label in enumerate(labels)
    ]

    if show_values:
        cell_text = np.round(correlation_matrix, 3)
        cell_template = "%{text}"
    else:
        cell_text = None
        cell_template = None

    heatmap = go.Heatmap(
        z=correlation_matrix,
        x=labels,
        y=labels,
        colorscale=colors,
        zmid=0,  # Center diverging scale at 0
        zmin=-1,
        zmax=1,
        colorbar={
            "title": "Correlation",
            "tickmode": "linear",
            "tick0": -1,
            "dtick": 0.5,
        },
        text=cell_text,
        texttemplate=cell_template,
        textfont={"size": 12},
        hovertext=hover_text,
        hovertemplate="%{hovertext}<extra></extra>",
    )
    fig.add_trace(heatmap)

    # Layout: untitled axes, first method at the top
    fig.update_layout(
        title=title or "Method Agreement - Ranking Correlations",
        xaxis={
            "title": "",
            "side": "bottom",
        },
        yaxis={
            "title": "",
            "autorange": "reversed",  # Top to bottom
        },
        **theme_config["layout"],
        width=width or 800,
        height=height or 800,
    )

    apply_responsive_layout(fig)

    return fig
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def plot_importance_distribution(
    results: dict[str, Any],
    *,
    title: str | None = None,
    method: str | None = None,
    theme: str | None = None,
    color_scheme: str | None = None,
    width: int | None = None,
    height: int | None = None,
    bins: int = 30,
    overlay: bool = False,
) -> go.Figure:
    """Plot distribution of feature importance scores across methods.

    Creates histogram(s) showing the distribution of importance scores. Can either
    overlay all methods in a single plot or show them separately in subplots.
    Useful for understanding the spread and concentration of importance values.

    Parameters
    ----------
    results : dict[str, Any]
        Results from analyze_ml_importance() containing:
        - "method_results": dict - Individual method results with importances
          (read from "importances_mean" for "pfi", "importances" otherwise)
        - "methods_run": list[str] - Names of methods that ran successfully
    title : str | None, optional
        Plot title. If None, uses "Feature Importance Distribution" for a single
        plot or "Feature Importance Distribution by Method" for subplots.
    method : str | None, optional
        Show distribution for a single method only. If None, shows all methods.
        Valid values: "mdi", "pfi", "mda", "shap" (must be in methods_run)
    theme : str | None, optional
        Theme name ("default", "dark", "print", "presentation").
        If None, uses current global theme.
    color_scheme : str | None, optional
        Color scheme for histogram bars. If None, uses "set2".
        Recommended: "set2", "set3", "pastel" for qualitative.
        If there are more methods than colors in the scheme, colors are reused
        cyclically.
    width : int | None, optional
        Figure width in pixels. If None, uses 1000.
    height : int | None, optional
        Figure height in pixels. If None, uses 600 (single plot) or 400 per
        method (subplots).
    bins : int, optional
        Number of histogram bins. Default is 30.
    overlay : bool, optional
        If True and method is None, overlay all methods in single plot.
        If False and method is None, create subplot for each method.
        Default is False (subplots).

    Returns
    -------
    go.Figure
        Interactive Plotly histogram with:
        - Distribution of importance scores
        - Optional multiple methods overlaid or in subplots
        - Hover showing importance value and count per bin

    Raises
    ------
    ValueError
        If results dict is missing required keys or has invalid structure.
        If specified method was not run or doesn't exist.
    TypeError
        If parameters have incorrect types.

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation import analyze_ml_importance
    >>> from ml4t.diagnostic.visualization import plot_importance_distribution
    >>>
    >>> # Analyze importance
    >>> results = analyze_ml_importance(model, X, y)
    >>>
    >>> # Show all methods (subplots)
    >>> fig = plot_importance_distribution(results)
    >>> fig.show()
    >>>
    >>> # Overlay for comparison
    >>> fig = plot_importance_distribution(results, overlay=True)
    >>> fig.show()
    >>>
    >>> # Single method with custom bins
    >>> fig = plot_importance_distribution(
    ...     results,
    ...     method="pfi",
    ...     bins=50,
    ...     theme="dark"
    ... )

    Notes
    -----
    - Distributions reveal whether importance is concentrated or spread out
    - Overlay mode is best for comparing 2-3 methods; use subplots for more
    - Very skewed distributions may benefit from log scale (not implemented yet)
    - Statistics annotations (mean, median, quartiles) are not implemented yet
    - Consider binning strategy for features with very different importance ranges
    """
    # Validate inputs
    validate_plot_results(
        results,
        required_keys=["method_results", "methods_run"],
        function_name="plot_importance_distribution",
    )
    theme = validate_theme(theme)
    validate_positive_int(bins, "bins")

    # Note: color_scheme validation happens in get_color_scheme()

    methods_run = results["methods_run"]
    method_results = results["method_results"]

    # Determine which methods to plot
    if method is not None:
        if method not in methods_run:
            raise ValueError(
                f"Method '{method}' not found in results. Available methods: {methods_run}"
            )
        methods_to_plot = [method]
    else:
        methods_to_plot = methods_run

    # Get theme configuration
    theme_config = get_theme_config(theme)

    # Get the full palette. Traces index into it modulo its length so that
    # plotting more methods than available colors reuses colors instead of
    # raising IndexError.
    palette = get_color_scheme(color_scheme or "set2")
    n_colors = len(palette)

    # Extract importance scores for each method ("pfi" stores them under a
    # different key than the other methods)
    method_scores = {}
    for method_name in methods_to_plot:
        result = method_results[method_name]
        scores = result["importances_mean"] if method_name == "pfi" else result["importances"]
        method_scores[method_name] = scores

    # Create figure
    if overlay or len(methods_to_plot) == 1:
        # Single plot with overlaid histograms
        fig = go.Figure()

        for i, (method_name, scores) in enumerate(method_scores.items()):
            fig.add_trace(
                go.Histogram(
                    x=scores,
                    name=method_name.upper(),
                    nbinsx=bins,
                    marker_color=palette[i % n_colors],
                    # Semi-transparent only when traces actually overlap
                    opacity=0.7 if overlay else 1.0,
                    hovertemplate=(
                        f"<b>{method_name.upper()}</b><br>Importance: %{{x:.4f}}<br>Count: %{{y}}<extra></extra>"
                    ),
                )
            )

        fig.update_layout(
            title=title or "Feature Importance Distribution",
            xaxis_title="Importance Score",
            yaxis_title="Frequency",
            barmode="overlay" if overlay else "stack",
            **theme_config["layout"],
            width=width or 1000,
            height=height or 600,
        )

    else:
        # Subplots for each method (one row per method)
        n_methods = len(methods_to_plot)
        fig = make_subplots(
            rows=n_methods,
            cols=1,
            subplot_titles=[m.upper() for m in methods_to_plot],
            vertical_spacing=0.1,
        )

        # Subplot rows are 1-based, palette positions are 0-based
        for i, (method_name, scores) in enumerate(method_scores.items(), start=1):
            fig.add_trace(
                go.Histogram(
                    x=scores,
                    nbinsx=bins,
                    marker_color=palette[(i - 1) % n_colors],
                    name=method_name.upper(),
                    showlegend=False,
                    hovertemplate=(
                        f"<b>{method_name.upper()}</b><br>Importance: %{{x:.4f}}<br>Count: %{{y}}<extra></extra>"
                    ),
                ),
                row=i,
                col=1,
            )

            # Update subplot axes
            fig.update_xaxes(title_text="Importance Score", row=i, col=1)
            fig.update_yaxes(title_text="Frequency", row=i, col=1)

        fig.update_layout(
            title=title or "Feature Importance Distribution by Method",
            **theme_config["layout"],
            width=width or 1000,
            height=height or (400 * n_methods),
        )

    # Apply responsive layout
    apply_responsive_layout(fig)

    return fig
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def plot_importance_summary(
    results: dict[str, Any],
    *,
    title: str | None = None,
    top_n: int = 15,
    theme: str | None = None,
    width: int | None = None,
    height: int | None = None,
) -> go.Figure:
    """Create comprehensive multi-panel feature importance summary visualization.

    Combines multiple views into a single figure:
    - Top-left: Bar chart of consensus rankings
    - Top-right: Method agreement heatmap
    - Bottom: Distribution of importance scores

    This provides a complete overview of feature importance analysis in one plot,
    ideal for reports and presentations.

    Parameters
    ----------
    results : dict[str, Any]
        Results from analyze_ml_importance() containing all required data:
        "consensus_ranking", "method_results", "method_agreement" and
        "methods_run".
    title : str | None, optional
        Overall figure title. If None, uses "Feature Importance Analysis - Summary"
    top_n : int, optional
        Number of top features to show in bar chart. Default is 15.
    theme : str | None, optional
        Theme name ("default", "dark", "print", "presentation").
        If None, uses current global theme.
    width : int | None, optional
        Figure width in pixels. If None, uses 1400.
    height : int | None, optional
        Figure height in pixels. If None, uses 1000.

    Returns
    -------
    go.Figure
        Multi-panel Plotly figure with comprehensive importance summary

    Raises
    ------
    ValueError
        If results dict is missing required keys or has invalid structure.
    TypeError
        If parameters have incorrect types.

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation import analyze_ml_importance
    >>> from ml4t.diagnostic.visualization import plot_importance_summary
    >>>
    >>> # Analyze importance
    >>> results = analyze_ml_importance(model, X, y)
    >>>
    >>> # Create comprehensive summary
    >>> fig = plot_importance_summary(results)
    >>> fig.show()
    >>>
    >>> # Save for report
    >>> fig = plot_importance_summary(
    ...     results,
    ...     title="Model Feature Importance Analysis",
    ...     theme="print",
    ...     top_n=20
    ... )
    >>> fig.write_html("importance_summary.html")
    >>> fig.write_image("importance_summary.pdf")

    Notes
    -----
    - This is the recommended visualization for comprehensive reports
    - All panels use consistent theming and color schemes
    - Interactive hover works independently for each panel
    - May require large display or high resolution for optimal viewing
    - Consider using individual plot functions for more customization
    - The bar chart shows, per feature, the mean importance over every method
      that reports that feature (0.0 if no method reports it)
    - Method pairs missing from "method_agreement" are drawn with correlation 0.0
    - The distribution panel always uses 30 bins and the "set2" scheme; colors
      are reused cyclically if there are more methods than colors
    """
    # Validate inputs
    validate_plot_results(
        results,
        required_keys=["consensus_ranking", "method_results", "method_agreement", "methods_run"],
        function_name="plot_importance_summary",
    )
    theme = validate_theme(theme)
    validate_positive_int(top_n, "top_n")

    # Get theme configuration
    theme_config = get_theme_config(theme)

    # Create subplots: 2x2 layout
    # Row 1: Bar chart (left), Heatmap (right)
    # Row 2: Distribution (spans both columns)
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Consensus Rankings (Top Features)",
            "Method Agreement",
            "Importance Score Distributions",
            "",  # Empty subtitle for merged cell
        ),
        specs=[
            [{"type": "bar"}, {"type": "heatmap"}],
            [{"type": "histogram", "colspan": 2}, None],
        ],
        vertical_spacing=0.15,
        horizontal_spacing=0.12,
    )

    # === Panel 1: Bar chart ===
    all_features = results["consensus_ranking"]
    features = all_features[:top_n]
    method_results = results["method_results"]

    # Resolve each method's score vector and a feature-name -> position lookup
    # once, instead of re-scanning feature_names for every (feature, method) pair.
    method_lookups = []
    for method_name, method_result in method_results.items():
        importances = (
            method_result["importances_mean"]
            if method_name == "pfi"
            else method_result["importances"]
        )
        positions: dict[Any, int] = {}
        for idx, name in enumerate(method_result["feature_names"]):
            # Keep the first occurrence, matching list.index() semantics
            positions.setdefault(name, idx)
        method_lookups.append((importances, positions))

    # Calculate average importance across the methods that report each feature
    importance_scores = []
    for feat in features:
        scores = [
            importances[positions[feat]]
            for importances, positions in method_lookups
            if feat in positions
        ]
        importance_scores.append(float(np.mean(scores)) if scores else 0.0)

    colors_bar = get_colorscale("viridis")

    fig.add_trace(
        go.Bar(
            x=importance_scores,
            y=features,
            orientation="h",
            marker={
                "color": importance_scores,
                "colorscale": colors_bar,
                "showscale": False,
            },
            hovertemplate="<b>%{y}</b><br>Importance: %{x:.4f}<extra></extra>",
            showlegend=False,
        ),
        row=1,
        col=1,
    )

    # === Panel 2: Heatmap ===
    methods = results["methods_run"]
    n_methods = len(methods)
    # Start from the identity (each method correlates perfectly with itself)
    correlation_matrix = np.eye(n_methods)
    method_agreement = results["method_agreement"]

    # Fill the symmetric off-diagonal entries; the agreement dict may key a
    # pair in either order, and a missing pair falls back to 0.0.
    for i, method1 in enumerate(methods):
        for j, method2 in enumerate(methods):
            if i < j:
                key1 = f"{method1}_vs_{method2}"
                key2 = f"{method2}_vs_{method1}"
                corr = method_agreement.get(key1, method_agreement.get(key2, 0.0))
                correlation_matrix[i, j] = corr
                correlation_matrix[j, i] = corr

    colors_heatmap = get_colorscale("rdbu")

    fig.add_trace(
        go.Heatmap(
            z=correlation_matrix,
            x=[m.upper() for m in methods],
            y=[m.upper() for m in methods],
            colorscale=colors_heatmap,
            zmid=0,
            zmin=-1,
            zmax=1,
            showscale=True,
            colorbar={
                "title": "Correlation",
                "x": 1.15,  # Position to right of subplot
                "len": 0.4,
            },
            text=np.round(correlation_matrix, 2),
            texttemplate="%{text}",
            textfont={"size": 10},
            hovertemplate=("<b>%{x}</b> vs <b>%{y}</b><br>Correlation: %{z:.3f}<extra></extra>"),
        ),
        row=1,
        col=2,
    )

    # === Panel 3: Distribution (overlay) ===
    # Index the palette modulo its length so that more methods than colors
    # reuses colors instead of raising IndexError.
    palette_dist = get_color_scheme("set2")
    n_colors = len(palette_dist)

    for i, method_name in enumerate(methods):
        result = method_results[method_name]
        scores = result["importances_mean"] if method_name == "pfi" else result["importances"]

        fig.add_trace(
            go.Histogram(
                x=scores,
                name=method_name.upper(),
                nbinsx=30,
                marker_color=palette_dist[i % n_colors],
                opacity=0.7,
                hovertemplate=(
                    f"<b>{method_name.upper()}</b><br>Importance: %{{x:.4f}}<br>Count: %{{y}}<extra></extra>"
                ),
            ),
            row=2,
            col=1,
        )

    # Update axes
    fig.update_xaxes(title_text="Importance Score", row=1, col=1)
    fig.update_yaxes(title_text="Features", row=1, col=1)
    fig.update_xaxes(title_text="", row=1, col=2)
    fig.update_yaxes(title_text="", autorange="reversed", row=1, col=2)
    fig.update_xaxes(title_text="Importance Score", row=2, col=1)
    fig.update_yaxes(title_text="Frequency", row=2, col=1)

    # Update layout
    fig.update_layout(
        title={
            "text": title or "Feature Importance Analysis - Summary",
            "x": 0.5,
            "xanchor": "center",
        },
        barmode="overlay",
        **theme_config["layout"],
        width=width or 1400,
        height=height or 1000,
        showlegend=True,
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
    )

    # Apply responsive layout
    apply_responsive_layout(fig)

    return fig
|