ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,649 @@
|
|
|
1
|
+
"""Importance data extraction for visualization layer.
|
|
2
|
+
|
|
3
|
+
Extracts comprehensive visualization data from feature importance analysis results.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from .types import (
|
|
14
|
+
FeatureDetailData,
|
|
15
|
+
ImportanceVizData,
|
|
16
|
+
LLMContextData,
|
|
17
|
+
MethodComparisonData,
|
|
18
|
+
MethodImportanceData,
|
|
19
|
+
UncertaintyData,
|
|
20
|
+
)
|
|
21
|
+
from .validation import _validate_lengths_match
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def extract_importance_viz_data(
    importance_results: dict[str, Any],
    include_uncertainty: bool = True,
    include_distributions: bool = True,
    include_per_feature: bool = True,
    include_llm_context: bool = True,
) -> ImportanceVizData:
    """Turn raw feature-importance results into a visualization-ready package.

    The returned structure bundles a high-level summary, per-method
    breakdowns, optional per-feature drill-down views, optional uncertainty
    metrics, a method-comparison section, run metadata, and optional
    auto-generated narratives intended for LLM consumption.

    Parameters
    ----------
    importance_results : dict
        Output of ``analyze_ml_importance()``. Keys read here (all optional):

        - ``'consensus_ranking'``: features in consensus importance order
        - ``'method_results'``: ``{method_name: method_result}``
        - ``'method_agreement'``: pairwise method correlations
        - ``'interpretation'``: analysis interpretation
        - ``'warnings'``: list of warning messages
        - ``'methods_run'``: method names; defaults to the keys of
          ``'method_results'``
    include_uncertainty : bool, default=True
        Compute stability / confidence-interval / coefficient-of-variation
        metrics. When False, the ``uncertainty`` section holds empty dicts.
    include_distributions : bool, default=True
        Keep per-repeat values (PFI) in the per-method data, for detailed
        uncertainty plots.
    include_per_feature : bool, default=True
        Build per-feature drill-down views. When False, ``per_feature`` is an
        empty dict.
    include_llm_context : bool, default=True
        Generate narrative text. When False, a placeholder with an empty
        narrative, empty lists and ``analysis_quality="medium"`` is used.

    Returns
    -------
    ImportanceVizData
        Mapping with keys ``summary``, ``per_method``, ``per_feature``,
        ``uncertainty``, ``method_comparison``, ``metadata`` and
        ``llm_context``. See the ImportanceVizData TypedDict for details.

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation import analyze_ml_importance
    >>> from ml4t.diagnostic.visualization.data_extraction import extract_importance_viz_data
    >>>
    >>> # Analyze importance
    >>> results = analyze_ml_importance(model, X, y, methods=['mdi', 'pfi'])
    >>>
    >>> # Extract visualization data
    >>> viz_data = extract_importance_viz_data(results)
    >>>
    >>> # Access different views
    >>> print(viz_data['summary']['n_features']) # High-level summary
    >>> print(viz_data['per_method']['mdi']['ranking'][:5]) # Top 5 by MDI
    >>> print(viz_data['per_feature']['momentum']['method_ranks']) # Feature detail
    >>> print(viz_data['llm_context']['key_insights']) # Auto-generated insights

    Notes
    -----
    - Output is meant both for human-facing dashboards and for LLM input.
    - Per-feature views power drill-down dashboards.
    - Uncertainty metrics power confidence visualizations.
    - The ``metadata['analysis_timestamp']`` is taken from the local clock
      at extraction time.
    """
    fetch = importance_results.get
    consensus_ranking = fetch("consensus_ranking", [])
    method_results = fetch("method_results", {})
    method_agreement = fetch("method_agreement", {})
    interpretation = fetch("interpretation", {})
    warnings = fetch("warnings", [])
    # Fall back to whatever methods actually produced results.
    methods_run = fetch("methods_run", list(method_results.keys()))

    n_features, n_methods = len(consensus_ranking), len(methods_run)

    summary = _build_summary(
        consensus_ranking, method_agreement, methods_run, n_features, n_methods, warnings
    )
    per_method = _extract_per_method_data(
        method_results, include_distributions=include_distributions
    )

    if include_per_feature:
        per_feature = _build_per_feature_data(
            consensus_ranking, method_results, method_agreement, methods_run
        )
    else:
        per_feature = {}

    uncertainty_data: UncertaintyData
    if include_uncertainty:
        uncertainty_data = _compute_uncertainty_metrics(method_results, consensus_ranking)
    else:
        uncertainty_data = {
            "method_stability": {},
            "rank_stability": {},
            "confidence_intervals": {},
            "coefficient_of_variation": {},
        }

    method_comparison = _build_method_comparison(method_agreement, method_results, methods_run)

    metadata = dict(
        n_features=n_features,
        n_methods=n_methods,
        methods_run=methods_run,
        analysis_timestamp=datetime.now().isoformat(),
        warnings=warnings,
        interpretation=interpretation,
    )

    llm_context: LLMContextData
    if include_llm_context:
        llm_context = _generate_llm_context(
            summary, per_method, method_comparison, uncertainty_data, warnings
        )
    else:
        # Placeholder so consumers always find the same keys.
        llm_context = {
            "summary_narrative": "",
            "key_insights": [],
            "recommendations": [],
            "caveats": [],
            "analysis_quality": "medium",
        }

    return ImportanceVizData(
        summary=summary,
        per_method=per_method,
        per_feature=per_feature,
        uncertainty=uncertainty_data,
        method_comparison=method_comparison,
        metadata=metadata,
        llm_context=llm_context,
    )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# =============================================================================
|
|
165
|
+
# Helper Functions
|
|
166
|
+
# =============================================================================
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _build_summary(
|
|
170
|
+
consensus_ranking: list[str],
|
|
171
|
+
method_agreement: dict[str, float],
|
|
172
|
+
methods_run: list[str],
|
|
173
|
+
n_features: int,
|
|
174
|
+
n_methods: int,
|
|
175
|
+
warnings: list[str],
|
|
176
|
+
) -> dict[str, Any]:
|
|
177
|
+
"""Build high-level summary statistics."""
|
|
178
|
+
# Compute average agreement
|
|
179
|
+
if method_agreement:
|
|
180
|
+
avg_agreement = float(np.mean(list(method_agreement.values())))
|
|
181
|
+
else:
|
|
182
|
+
avg_agreement = 1.0 if n_methods == 1 else 0.0
|
|
183
|
+
|
|
184
|
+
# Determine agreement level
|
|
185
|
+
if avg_agreement > 0.8:
|
|
186
|
+
agreement_level = "high"
|
|
187
|
+
elif avg_agreement > 0.6:
|
|
188
|
+
agreement_level = "medium"
|
|
189
|
+
else:
|
|
190
|
+
agreement_level = "low"
|
|
191
|
+
|
|
192
|
+
return {
|
|
193
|
+
"n_features": n_features,
|
|
194
|
+
"n_methods": n_methods,
|
|
195
|
+
"methods_run": methods_run,
|
|
196
|
+
"top_feature": consensus_ranking[0] if consensus_ranking else None,
|
|
197
|
+
"consensus_ranking": consensus_ranking,
|
|
198
|
+
"avg_method_agreement": avg_agreement,
|
|
199
|
+
"agreement_level": agreement_level,
|
|
200
|
+
"has_warnings": len(warnings) > 0,
|
|
201
|
+
"warnings_count": len(warnings),
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _extract_per_method_data(
    method_results: dict[str, dict], include_distributions: bool = True
) -> dict[str, MethodImportanceData]:
    """Extract detailed per-method importance data with normalized values.

    Parameters
    ----------
    method_results : dict
        ``{method_name: method_result}``. The ``"pfi"`` entry is read via
        ``feature_names``, ``importances_mean``, ``importances_std``,
        ``importances_raw``, ``n_repeats`` and ``scoring``. Every other
        method is read via ``feature_names`` and ``importances``.
    include_distributions : bool, default=True
        Include PFI per-repeat values as ``raw_values``.

    Returns
    -------
    dict[str, MethodImportanceData]
        One entry per method, with importances rescaled to sum to 1.0 where
        possible and a ranking sorted by descending importance.

    Raises
    ------
    Whatever ``_validate_lengths_match`` raises when the per-feature arrays
    disagree in length.
    """
    per_method: dict[str, MethodImportanceData] = {}
    for method_name, method_result in method_results.items():
        if method_name == "pfi":
            per_method[method_name] = _extract_pfi_method_data(
                method_result, include_distributions
            )
        else:
            per_method[method_name] = _extract_single_value_method_data(method_result)
    return per_method


def _extract_pfi_method_data(
    method_result: dict, include_distributions: bool
) -> MethodImportanceData:
    """Build MethodImportanceData for PFI (mean, std, CI and per-repeat values)."""
    feature_names = method_result.get("feature_names", [])
    importances_mean = method_result.get("importances_mean", [])
    importances_std = method_result.get("importances_std", [])
    importances_raw = method_result.get("importances_raw", [])

    _validate_lengths_match(
        ("feature_names", feature_names),
        ("importances_mean", importances_mean),
        ("importances_std", importances_std),
    )

    # Rescale so the means sum to 1.0 (percentage basis). The same factor is
    # applied to std and to the per-repeat values below so that every
    # quantity reported for this method shares one scale.
    total = sum(importances_mean)
    normalize = total > 0
    if normalize:
        importances_mean = [imp / total for imp in importances_mean]
        importances_std = [std / total for std in importances_std]

    importances_dict = dict(zip(feature_names, importances_mean, strict=True))
    std_dict = dict(zip(feature_names, importances_std, strict=True))

    # 95% normal-approximation CI of the mean, using the standard error
    # std / sqrt(n_repeats).
    n_repeats = method_result.get("n_repeats", 1)
    sqrt_n = np.sqrt(max(n_repeats, 1))
    ci_dict = {}
    for feat, mean, std in zip(feature_names, importances_mean, importances_std, strict=True):
        se = std / sqrt_n
        ci_dict[feat] = (float(mean - 1.96 * se), float(mean + 1.96 * se))

    # Per-repeat values, rescaled consistently with the mean/std/CI above.
    raw_list = None
    if include_distributions and importances_raw is not None and len(importances_raw) > 0:
        raw_list = []
        for repeat_values in importances_raw:
            scaled = [val / total for val in repeat_values] if normalize else repeat_values
            raw_list.append(dict(zip(feature_names, scaled, strict=False)))

    return MethodImportanceData(
        importances=importances_dict,
        ranking=sorted(feature_names, key=lambda f: importances_dict[f], reverse=True),
        std=std_dict,
        confidence_intervals=ci_dict,
        raw_values=raw_list,
        metadata={
            "n_repeats": n_repeats,
            "scoring": method_result.get("scoring", "unknown"),
        },
    )


def _extract_single_value_method_data(method_result: dict) -> MethodImportanceData:
    """Build MethodImportanceData for single-value methods (e.g. MDI, MDA, SHAP)."""
    feature_names = method_result.get("feature_names", [])
    importances = method_result.get("importances", [])

    _validate_lengths_match(
        ("feature_names", feature_names),
        ("importances", importances),
    )

    # Rescale to sum to 1.0 unless the values already do (within 0.01).
    total = sum(importances)
    if total > 0 and abs(total - 1.0) > 0.01:
        importances = [imp / total for imp in importances]

    importances_dict = dict(zip(feature_names, importances, strict=True))

    return MethodImportanceData(
        importances=importances_dict,
        ranking=sorted(feature_names, key=lambda f: importances_dict[f], reverse=True),
        std=None,
        confidence_intervals=None,
        raw_values=None,
        metadata={},
    )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _build_per_feature_data(
    consensus_ranking: list[str],
    method_results: dict[str, dict],
    _method_agreement: dict[str, float],
    methods_run: list[str],
) -> dict[str, FeatureDetailData]:
    """Build per-feature aggregated views for drill-down.

    For every feature in ``consensus_ranking``, this collects its rank and
    score under each method in ``methods_run`` and the PFI std when
    available. It then derives a rank-variance-based agreement level and a
    stability score.

    Scores are taken directly from ``method_results`` (``importances_mean``
    for PFI, ``importances`` otherwise), without the sum-to-one
    normalization applied in ``_extract_per_method_data``.

    Methods listed in ``methods_run`` but without any ranked features are
    skipped. Previously they ranked every feature 1st and scored it 0.0,
    which distorted agreement.

    Parameters
    ----------
    consensus_ranking : list[str]
        Features in consensus order. Their positions (1-based) become
        ``consensus_rank``.
    method_results : dict
        ``{method_name: method_result}`` as produced by the importance analysis.
    _method_agreement : dict
        Unused; kept for interface compatibility.
    methods_run : list[str]
        Methods to report per feature, in this order.

    Returns
    -------
    dict[str, FeatureDetailData]
        Keyed by feature name.
    """
    per_feature: dict[str, FeatureDetailData] = {}

    # Build per-method lookup tables once instead of calling list.index()
    # for every (feature, method) pair.
    method_importances: dict[str, dict[str, float]] = {}
    method_rank_maps: dict[str, dict[str, int]] = {}
    for method_name, method_result in method_results.items():
        feature_names = method_result.get("feature_names", [])
        key = "importances_mean" if method_name == "pfi" else "importances"
        importances = method_result.get(key, [])

        scores = dict(zip(feature_names, importances, strict=False))
        ranking = sorted(feature_names, key=lambda f: scores.get(f, 0), reverse=True)

        # setdefault keeps the first (best) position, matching list.index().
        rank_map: dict[str, int] = {}
        for position, feat in enumerate(ranking, start=1):
            rank_map.setdefault(feat, position)

        method_importances[method_name] = scores
        method_rank_maps[method_name] = rank_map

    # PFI standard deviations keyed by feature (first occurrence, bounds-checked).
    pfi_result = method_results.get("pfi", {})
    pfi_std_values = pfi_result.get("importances_std", [])
    pfi_first_index: dict[str, int] = {}
    for idx, feat in enumerate(pfi_result.get("feature_names", [])):
        pfi_first_index.setdefault(feat, idx)
    pfi_stds = {
        feat: pfi_std_values[idx]
        for feat, idx in pfi_first_index.items()
        if idx < len(pfi_std_values)
    }

    for consensus_rank, feature_name in enumerate(consensus_ranking, start=1):
        method_ranks: dict[str, int] = {}
        method_scores: dict[str, float] = {}
        method_stds: dict[str, float] = {}

        for method_name in methods_run:
            rank_map = method_rank_maps.get(method_name)
            if not rank_map:
                # No ranking data for this method: it cannot vote.
                continue

            # Features missing from this method's ranking go after its last entry.
            method_ranks[method_name] = rank_map.get(feature_name, len(rank_map) + 1)
            method_scores[method_name] = method_importances[method_name].get(feature_name, 0.0)

            if method_name == "pfi" and feature_name in pfi_stds:
                method_stds[method_name] = pfi_stds[feature_name]

        if len(method_ranks) > 1:
            rank_variance = float(np.var(list(method_ranks.values())))
            if rank_variance < 2:
                agreement_level = "high"
            elif rank_variance < 10:
                agreement_level = "medium"
            else:
                agreement_level = "low"
            # Inverse of rank variance, mapped into (0, 1].
            stability_score = 1.0 / (1.0 + rank_variance)
        else:
            agreement_level = "n/a"
            stability_score = 1.0

        # Avoid np.mean([]) -> NaN (+ RuntimeWarning) when no method contributed.
        consensus_score = (
            float(np.mean(list(method_scores.values()))) if method_scores else 0.0
        )

        interpretation = _generate_feature_interpretation(
            feature_name, consensus_rank, method_ranks, agreement_level
        )

        per_feature[feature_name] = FeatureDetailData(
            consensus_rank=consensus_rank,
            consensus_score=consensus_score,
            method_ranks=method_ranks,
            method_scores=method_scores,
            method_stds=method_stds,
            agreement_level=agreement_level,
            stability_score=float(stability_score),
            interpretation=interpretation,
        )

    return per_feature
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def _compute_uncertainty_metrics(
    method_results: dict[str, dict], consensus_ranking: list[str]
) -> UncertaintyData:
    """Compute uncertainty and stability metrics.

    Only permutation feature importance (``"pfi"``) carries repeat statistics
    (mean/std over repeats), so it is currently the only method that contributes
    stability, coefficient-of-variation and confidence-interval entries.

    Args:
        method_results: Per-method result dicts. The ``"pfi"`` entry, when present
            and non-empty, is read for ``"feature_names"``, ``"importances_mean"``,
            ``"importances_std"`` (all the same length) and ``"n_repeats"``.
        consensus_ranking: Feature names in consensus order; each one receives an
            empty rank-stability placeholder list.

    Returns:
        UncertaintyData with:
          - ``method_stability``: ``1 - mean(CV)`` clamped to ``[0, 1]``, where the
            mean is taken over features with non-zero mean importance (1.0 if none).
          - ``coefficient_of_variation``: per-feature ``std / |mean|``; 0.0 when
            the mean is exactly zero.
          - ``confidence_intervals``: per-feature normal-approximation interval
            ``mean +/- 1.96 * std / sqrt(n_repeats)``.
          - ``rank_stability``: empty placeholder lists (no bootstrap data yet).

    Raises:
        Length mismatches among the PFI arrays are checked by
        ``_validate_lengths_match``; ``zip(..., strict=True)`` also raises
        ``ValueError`` on unequal lengths.
    """
    method_stability: dict[str, float] = {}
    confidence_intervals: dict[str, dict[str, tuple[float, float]]] = {}
    coefficient_of_variation: dict[str, dict[str, float]] = {}

    # For now, focus on PFI: it is the only method with repeat (mean/std) data.
    pfi_result = method_results.get("pfi", {})
    if pfi_result:
        feature_names = pfi_result.get("feature_names", [])
        importances_mean = pfi_result.get("importances_mean", [])
        importances_std = pfi_result.get("importances_std", [])

        # Validate length consistency
        _validate_lengths_match(
            ("feature_names", feature_names),
            ("importances_mean", importances_mean),
            ("importances_std", importances_std),
        )

        # A missing, None or non-positive repeat count falls back to 1 (SE == std).
        n_repeats = pfi_result.get("n_repeats") or 1
        sqrt_n = np.sqrt(max(n_repeats, 1))
        # Normal 97.5% quantile. With few repeats a Student-t quantile would give a
        # wider (more honest) interval; kept as the normal approximation here.
        z = 1.96

        cvs: list[float] = []
        cv_dict: dict[str, float] = {}
        ci_dict: dict[str, tuple[float, float]] = {}
        for feat, mean, std in zip(feature_names, importances_mean, importances_std, strict=True):
            if mean != 0:
                cv = std / abs(mean)
                cvs.append(cv)
                cv_dict[feat] = float(cv)
            else:
                # CV is undefined for a zero mean: report 0.0 and exclude it from
                # the stability average below.
                cv_dict[feat] = 0.0

            se = std / sqrt_n  # Standard error of the mean
            ci_dict[feat] = (float(mean - z * se), float(mean + z * se))

        # Average CV can exceed 1 for noisy features; clamp so the score stays a
        # proper [0, 1] stability measure instead of going negative.
        raw_stability = 1.0 - float(np.mean(cvs)) if cvs else 1.0
        method_stability["pfi"] = min(1.0, max(0.0, raw_stability))
        coefficient_of_variation["pfi"] = cv_dict
        confidence_intervals["pfi"] = ci_dict

    # Rank stability would track bootstrap rank distributions; until that data
    # exists, every consensus feature gets an empty placeholder list.
    rank_stability: dict[str, list[int]] = {feat: [] for feat in consensus_ranking}

    return UncertaintyData(
        method_stability=method_stability,
        rank_stability=rank_stability,
        confidence_intervals=confidence_intervals,
        coefficient_of_variation=coefficient_of_variation,
    )
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def _build_method_comparison(
    method_agreement: dict[str, float], method_results: dict[str, dict], methods_run: list[str]
) -> MethodComparisonData:
    """Build method comparison metrics.

    Args:
        method_agreement: Pairwise agreement values keyed ``"<a>_vs_<b>"``; either
            ordering of a pair is accepted.
        method_results: Per-method result dicts with ``"feature_names"`` and either
            ``"importances_mean"`` (for ``"pfi"``) or ``"importances"`` (all other
            methods), of equal length.
        methods_run: Method names; defines the row/column order of the
            correlation matrix and the pair order of ``rank_differences``.

    Returns:
        MethodComparisonData with:
          - ``correlation_matrix``: square list of lists over ``methods_run``; the
            diagonal is 1.0 and pairs missing from ``method_agreement`` are 0.0.
          - ``correlation_methods``: ``methods_run``.
          - ``rank_differences``: for each pair ``(m1, m2)`` with ``m1`` earlier in
            ``methods_run``, the absolute difference of 1-based ranks for every
            feature ranked by both methods. A method absent from
            ``method_results`` yields an empty mapping.
          - ``agreement_summary``: ``method_agreement`` unchanged.
    """

    def _agreement(method1: str, method2: str) -> float:
        # Look up the agreement for a pair of methods in a correlation-matrix cell.
        if method1 == method2:
            return 1.0
        # Agreement may have been recorded under either ordering of the pair.
        key1 = f"{method1}_vs_{method2}"
        key2 = f"{method2}_vs_{method1}"
        return float(method_agreement.get(key1, method_agreement.get(key2, 0.0)))

    correlation_matrix = [[_agreement(m1, m2) for m2 in methods_run] for m1 in methods_run]

    # 1-based rank of each feature per method, ordered by descending importance.
    rank_positions: dict[str, dict[str, int]] = {}
    for method_name, method_result in method_results.items():
        feature_names = method_result.get("feature_names", [])
        importance_key = "importances_mean" if method_name == "pfi" else "importances"
        importances = method_result.get(importance_key, [])

        # Validate length consistency
        _validate_lengths_match(
            ("feature_names", feature_names),
            ("importances", importances),
        )

        importances_dict = dict(zip(feature_names, importances, strict=True))
        # sorted() is stable (also with reverse=True): tied features keep input order.
        ranking = sorted(feature_names, key=importances_dict.__getitem__, reverse=True)

        positions: dict[str, int] = {}
        for rank, feat in enumerate(ranking, start=1):
            # If a name is duplicated, its first (best) position wins.
            positions.setdefault(feat, rank)
        rank_positions[method_name] = positions

    # Precomputed positions give O(1) lookups instead of scanning rankings per feature.
    rank_differences: dict[tuple[str, str], dict[str, int]] = {}
    for i, method1 in enumerate(methods_run):
        ranks1 = rank_positions.get(method1, {})
        for method2 in methods_run[i + 1 :]:
            ranks2 = rank_positions.get(method2, {})
            rank_differences[(method1, method2)] = {
                feat: abs(rank1 - ranks2[feat]) for feat, rank1 in ranks1.items() if feat in ranks2
            }

    return MethodComparisonData(
        correlation_matrix=correlation_matrix,
        correlation_methods=methods_run,
        rank_differences=rank_differences,
        agreement_summary=method_agreement,
    )
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def _generate_feature_interpretation(
|
|
514
|
+
feature_name: str, consensus_rank: int, method_ranks: dict[str, int], agreement_level: str
|
|
515
|
+
) -> str:
|
|
516
|
+
"""Generate auto-interpretation for a single feature."""
|
|
517
|
+
if agreement_level == "high":
|
|
518
|
+
return (
|
|
519
|
+
f"'{feature_name}' ranks #{consensus_rank} with strong consensus across methods. "
|
|
520
|
+
f"All methods agree on its importance level."
|
|
521
|
+
)
|
|
522
|
+
elif agreement_level == "medium":
|
|
523
|
+
rank_str = ", ".join([f"{m}=#{r}" for m, r in method_ranks.items()])
|
|
524
|
+
return (
|
|
525
|
+
f"'{feature_name}' ranks #{consensus_rank} overall but shows moderate variation "
|
|
526
|
+
f"across methods ({rank_str}). Consider investigating method-specific biases."
|
|
527
|
+
)
|
|
528
|
+
else:
|
|
529
|
+
rank_str = ", ".join([f"{m}=#{r}" for m, r in method_ranks.items()])
|
|
530
|
+
return (
|
|
531
|
+
f"'{feature_name}' ranks #{consensus_rank} but shows significant disagreement "
|
|
532
|
+
f"across methods ({rank_str}). This may indicate interaction effects or "
|
|
533
|
+
f"method-specific artifacts. Further investigation recommended."
|
|
534
|
+
)
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
def _generate_llm_context(
    summary: dict[str, Any],
    _per_method: dict[str, MethodImportanceData],
    _method_comparison: MethodComparisonData,
    uncertainty: UncertaintyData,
    warnings: list[str],
) -> LLMContextData:
    """Generate auto-narratives and insights for LLM consumption.

    Args:
        summary: Analysis summary. Must contain ``"n_features"``, ``"n_methods"``,
            ``"methods_run"``, ``"top_feature"``, ``"avg_method_agreement"`` and
            ``"agreement_level"``. The agreement value is only formatted or
            compared when more than one method ran.
        _per_method: Unused; accepted for interface compatibility.
        _method_comparison: Unused; accepted for interface compatibility.
        uncertainty: Uncertainty metrics; only ``"method_stability"`` (method name
            to stability score) is read.
        warnings: Warnings emitted during the analysis; only their count matters.

    Returns:
        LLMContextData with a summary narrative, key insights, recommendations,
        caveats, and an overall ``analysis_quality`` of "high", "medium" or "low".
    """
    n_features = summary["n_features"]
    n_methods = summary["n_methods"]
    methods_run = summary["methods_run"]
    top_feature = summary["top_feature"]
    avg_agreement = summary["avg_method_agreement"]
    agreement_level = summary["agreement_level"]
    multi_method = n_methods > 1

    # Methods whose stability score falls below this threshold are flagged; the
    # same set drives both the insight lines and the recommendation.
    low_stability_threshold = 0.7
    method_stability = uncertainty.get("method_stability") or {}
    unstable_methods = {
        method: score
        for method, score in method_stability.items()
        if score < low_stability_threshold
    }

    # Build summary narrative ("1 method", otherwise "N methods", including 0).
    plural = "" if n_methods == 1 else "s"
    narrative_parts = [
        f"This feature importance analysis examined {n_features} features using "
        f"{n_methods} method{plural} ({', '.join(methods_run)}). "
    ]
    if top_feature:
        narrative_parts.append(
            f"The consensus ranking identified '{top_feature}' as the most important feature. "
        )
    if multi_method:
        narrative_parts.append(
            f"Method agreement is {agreement_level} (average correlation: {avg_agreement:.2f}). "
        )
    summary_narrative = "".join(narrative_parts)

    # Key insights: top feature, cross-method agreement, then unstable methods.
    key_insights = [
        f"Top consensus feature: '{top_feature}'"
        if top_feature
        else "No clear top feature identified"
    ]
    if multi_method:
        if agreement_level == "high":
            key_insights.append(
                f"Strong consensus across methods (avg correlation: {avg_agreement:.2f})"
            )
        elif agreement_level == "medium":
            key_insights.append(
                f"Moderate method agreement (avg correlation: {avg_agreement:.2f}) - some variation expected"
            )
        else:
            key_insights.append(
                f"Low method agreement (avg correlation: {avg_agreement:.2f}) - investigate method-specific biases"
            )
    key_insights.extend(
        f"{method.upper()} shows low stability (score: {score:.2f}) - "
        "importance estimates have high variance"
        for method, score in unstable_methods.items()
    )

    # Recommendations: agreement-driven, stability-driven, then a general one.
    recommendations = []
    if multi_method and avg_agreement < 0.6:
        recommendations.append(
            "Investigate features with large rank disagreements between methods. "
            "This may indicate interaction effects or method-specific artifacts."
        )
    if unstable_methods:
        recommendations.append(
            "Increase number of repeats or use cross-validation to improve importance stability estimates."
        )
    recommendations.append(
        "Focus on top consensus features for model interpretability and feature selection."
    )

    # Caveats
    caveats = []
    if warnings:
        caveats.append(f"Analysis generated {len(warnings)} warning(s) - review carefully.")
    if n_methods == 1:
        caveats.append(
            "Only one method used. Consider running multiple methods to validate findings."
        )

    # Overall quality: "high" needs several agreeing methods and no warnings.
    if multi_method and avg_agreement > 0.7 and not warnings:
        analysis_quality = "high"
    elif multi_method and avg_agreement > 0.5:
        analysis_quality = "medium"
    else:
        analysis_quality = "low"

    return LLMContextData(
        summary_narrative=summary_narrative,
        key_insights=key_insights,
        recommendations=recommendations,
        caveats=caveats,
        analysis_quality=analysis_quality,
    )
|