ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
"""Interaction data extraction for visualization layer.
|
|
2
|
+
|
|
3
|
+
Extracts comprehensive visualization data from feature interaction analysis results.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Any, cast
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from .types import (
|
|
14
|
+
FeatureInteractionData,
|
|
15
|
+
InteractionMatrixData,
|
|
16
|
+
InteractionVizData,
|
|
17
|
+
LLMContextData,
|
|
18
|
+
NetworkGraphData,
|
|
19
|
+
)
|
|
20
|
+
from .validation import _validate_matrix_feature_alignment
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def extract_interaction_viz_data(
    interaction_results: dict[str, Any],
    importance_results: dict[str, Any] | None = None,
    n_top_partners: int = 5,
    cluster_threshold: float = 0.3,
    include_llm_context: bool = True,
) -> InteractionVizData:
    """Extract comprehensive visualization data from interaction analysis results.

    This function transforms raw SHAP interaction results into structured data
    optimized for rich interactive visualization, including per-feature summaries,
    network graph data, interaction matrices, and auto-generated insights.

    Parameters
    ----------
    interaction_results : dict
        Results from compute_shap_interactions() containing:
        - 'interaction_matrix': DataFrame with pairwise interactions
        - 'feature_names': list of feature names
        - 'shap_values': raw SHAP values (optional)
        - 'shap_interaction_values': raw interaction values (optional)
        Only 'interaction_matrix' and 'feature_names' are read here. The
        matrix may be any object exposing ``to_numpy()`` or anything accepted
        by ``np.array``.
    importance_results : dict, optional
        Optional importance results to cross-reference for node sizing.
        If it contains a 'consensus_ranking' sequence, network nodes are
        sized by ``1 / rank`` instead of total interaction strength.
    n_top_partners : int, default=5
        Number of top interaction partners to include per feature.
    cluster_threshold : float, default=0.3
        Minimum interaction strength to consider for clustering.
        Feature pairs whose absolute interaction is strictly above this
        threshold are linked when forming clusters.
    include_llm_context : bool, default=True
        Whether to generate auto-narratives for LLM consumption. When False,
        ``llm_context`` contains empty placeholder values.

    Returns
    -------
    InteractionVizData
        Complete structured data package with:
        - Per-feature interaction summaries (``cluster_id`` is the index of
          the feature's cluster in ``network_graph['clusters']``, or None if
          it belongs to no multi-feature cluster)
        - Network graph data (nodes, edges, clusters)
        - Interaction matrix data
        - Strength distribution statistics
        - Auto-generated LLM narratives

    Raises
    ------
    ValueError
        If ``interaction_results`` has no 'interaction_matrix' entry, or if
        fewer than two features are supplied (there is then no feature pair
        to summarise). Errors raised by ``_validate_matrix_feature_alignment``
        propagate unchanged.

    Examples
    --------
    >>> from ml4t.diagnostic.evaluation import compute_shap_interactions
    >>> from ml4t.diagnostic.visualization.data_extraction import extract_interaction_viz_data
    >>>
    >>> # Compute interactions
    >>> interaction_results = compute_shap_interactions(model, X, y)
    >>>
    >>> # Extract visualization data
    >>> viz_data = extract_interaction_viz_data(interaction_results)
    >>>
    >>> # Access different views
    >>> print(viz_data['summary']['strongest_interaction'])
    >>> print(viz_data['per_feature']['momentum']['top_partners'][:3])
    >>> print(viz_data['network_graph']['nodes'])
    >>> print(viz_data['llm_context']['key_insights'])

    Notes
    -----
    - Network graph data is pre-computed for custom rendering
    - Clustering identifies groups of strongly interacting features
    - Per-feature summaries enable drill-down dashboards
    - Cross-referencing with importance results enables better node sizing
    """
    # Extract basic info
    interaction_matrix_df = interaction_results.get("interaction_matrix")
    feature_names = interaction_results.get("feature_names", [])

    if interaction_matrix_df is None:
        raise ValueError("interaction_results must contain 'interaction_matrix'")

    # DataFrame-like inputs expose to_numpy(); everything else is coerced.
    if hasattr(interaction_matrix_df, "to_numpy"):
        interaction_matrix = interaction_matrix_df.to_numpy()
    else:
        interaction_matrix = np.array(interaction_matrix_df)

    # Validate matrix dimensions match feature names
    _validate_matrix_feature_alignment(interaction_matrix, feature_names)

    n_features = len(feature_names)

    # Build summary statistics
    summary = _build_interaction_summary(interaction_matrix, feature_names)

    # Build per-feature interaction data (cluster_id is left as None here)
    per_feature = _build_per_feature_interactions(interaction_matrix, feature_names, n_top_partners)

    # Build network graph data
    network_graph = _build_network_graph(
        interaction_matrix, feature_names, importance_results, cluster_threshold
    )

    # Back-fill each feature's cluster_id from the detected clusters; the
    # per-feature builder cannot do this because clustering happens later.
    for cluster_id, cluster_members in enumerate(network_graph["clusters"]):
        for member in cluster_members:
            if member in per_feature:
                per_feature[member]["cluster_id"] = cluster_id

    # Build matrix data
    matrix_data = _build_interaction_matrix_data(interaction_matrix, feature_names)

    # Build strength distribution
    strength_distribution = _build_strength_distribution(interaction_matrix)

    # Build metadata
    metadata = {
        "n_features": n_features,
        # Number of unordered feature pairs; integer arithmetic avoids a
        # float round-trip.
        "n_interactions": n_features * (n_features - 1) // 2,
        "analysis_timestamp": datetime.now().isoformat(),
        "cluster_threshold": cluster_threshold,
        "n_top_partners": n_top_partners,
    }

    # Placeholder LLM context, replaced when narratives are requested
    llm_context: LLMContextData = {
        "summary_narrative": "",
        "key_insights": [],
        "recommendations": [],
        "caveats": [],
        "analysis_quality": "medium",
    }
    if include_llm_context:
        llm_context = _generate_interaction_llm_context(
            summary, per_feature, network_graph, strength_distribution
        )

    return InteractionVizData(
        summary=summary,
        per_feature=per_feature,
        network_graph=network_graph,
        interaction_matrix=matrix_data,
        strength_distribution=strength_distribution,
        metadata=metadata,
        llm_context=llm_context,
    )
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# =============================================================================
|
|
158
|
+
# Interaction Analysis Helpers
|
|
159
|
+
# =============================================================================
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _build_interaction_summary(
    interaction_matrix: np.ndarray, feature_names: list[str]
) -> dict[str, Any]:
    """Build high-level summary statistics for pairwise interactions.

    Parameters
    ----------
    interaction_matrix : np.ndarray
        Square matrix of interaction values whose rows/columns follow the
        order of ``feature_names``.
    feature_names : list of str
        Feature labels.

    Returns
    -------
    dict
        ``n_features``; ``n_interactions`` (number of unordered pairs);
        ``strongest_interaction`` (signed value of the pair with the largest
        absolute interaction) and ``strongest_pair``; ``mean_interaction``,
        ``median_interaction`` and ``std_interaction`` over the absolute
        off-diagonal values; ``most_interactive_feature`` and its
        ``max_total_interaction`` (sum of absolute values over its row).

    Raises
    ------
    ValueError
        If fewer than two features are supplied, because no feature pair
        exists to summarise.
    """
    n_features = len(feature_names)
    if n_features < 2:
        # Previously np.argmax failed on the empty upper triangle with an
        # opaque "empty sequence" ValueError; fail early with a clear message.
        raise ValueError(
            f"Interaction summary requires at least 2 features to form a pair, got {n_features}"
        )

    # Strict upper triangle: every unordered pair once, diagonal excluded.
    pair_rows, pair_cols = np.triu_indices(n_features, k=1)
    interaction_values = interaction_matrix[pair_rows, pair_cols]
    abs_values = np.abs(interaction_values)

    # Strongest pair is chosen by magnitude; its signed value is reported.
    max_idx = int(np.argmax(abs_values))
    max_interaction = float(interaction_values[max_idx])
    strongest_pair = (
        feature_names[int(pair_rows[max_idx])],
        feature_names[int(pair_cols[max_idx])],
    )

    # Distribution statistics over absolute pairwise strengths
    mean_interaction = float(np.mean(abs_values))
    median_interaction = float(np.median(abs_values))
    std_interaction = float(np.std(abs_values))

    # Row totals include the diagonal entry as well as all off-diagonal
    # pairs. NOTE(review): confirm whether self-interaction should count here.
    total_interactions = np.sum(np.abs(interaction_matrix), axis=1)
    top_idx = int(np.argmax(total_interactions))
    most_interactive_feature = feature_names[top_idx]

    return {
        "n_features": n_features,
        "n_interactions": len(interaction_values),
        "strongest_interaction": max_interaction,
        "strongest_pair": strongest_pair,
        "mean_interaction": mean_interaction,
        "median_interaction": median_interaction,
        "std_interaction": std_interaction,
        "most_interactive_feature": most_interactive_feature,
        "max_total_interaction": float(total_interactions[top_idx]),
    }
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _build_per_feature_interactions(
    interaction_matrix: np.ndarray, feature_names: list[str], n_top_partners: int = 5
) -> dict[str, FeatureInteractionData]:
    """Summarise each feature's interactions with every other feature.

    For the row of ``interaction_matrix`` belonging to each feature, the
    off-diagonal entries are ranked by absolute value (ties keep column
    order). The first ``n_top_partners`` are kept as ``(partner_name,
    signed_value)`` tuples. ``total_interaction_strength`` is the sum of
    absolute values over the whole row, diagonal included. ``cluster_id`` is
    left as ``None``; clustering happens elsewhere.
    """
    summaries: dict[str, FeatureInteractionData] = {}

    for row_idx, name in enumerate(feature_names):
        row = interaction_matrix[row_idx, :]

        # Every other feature paired with its signed interaction value.
        candidates = [
            (other, float(row[col_idx]))
            for col_idx, other in enumerate(feature_names)
            if col_idx != row_idx
        ]
        # Strongest magnitude first; sorted() is stable, so ties keep order.
        ranked = sorted(candidates, key=lambda pair: abs(pair[1]), reverse=True)
        best = ranked[:n_top_partners]

        summaries[name] = FeatureInteractionData(
            feature_name=name,
            top_partners=best,
            total_interaction_strength=float(np.sum(np.abs(row))),
            cluster_id=None,
            interpretation=_generate_interaction_interpretation(name, best),
        )

    return summaries
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _build_network_graph(
    interaction_matrix: np.ndarray,
    feature_names: list[str],
    importance_results: dict[str, Any] | None,
    cluster_threshold: float,
) -> NetworkGraphData:
    """Build network graph data (nodes, edges, clusters).

    Nodes
        One per feature. ``importance`` is ``1 / rank`` taken from
        ``importance_results["consensus_ranking"]`` when that key is present
        (0.1 for features missing from the ranking). Otherwise it is the
        feature's total interaction. ``total_interaction`` is the sum of
        absolute values across the feature's row (diagonal included).
    Edges
        One per unordered pair (upper triangle) with a non-zero interaction,
        carrying the signed ``weight`` and ``abs_weight``, sorted by
        descending ``abs_weight``.
    Clusters
        Result of ``_detect_interaction_clusters`` with ``cluster_threshold``.
    """
    n_features = len(feature_names)

    # Resolve the importance ranking once. The previous per-feature
    # ``in`` + ``list.index`` scan made node construction O(n^2).
    rank_by_feature: dict[Any, int] | None = None
    if importance_results and "consensus_ranking" in importance_results:
        rank_by_feature = {}
        for position, ranked_name in enumerate(importance_results["consensus_ranking"]):
            # setdefault keeps the first occurrence, matching list.index().
            rank_by_feature.setdefault(ranked_name, position + 1)

    # Build nodes
    nodes = []
    for i, feature_name in enumerate(feature_names):
        total_interaction = float(np.sum(np.abs(interaction_matrix[i, :])))

        if rank_by_feature is None:
            # No ranking available: total interaction strength is the proxy.
            node_importance = total_interaction
        else:
            rank = rank_by_feature.get(feature_name)
            # Rank 1 = most important = largest node; unranked features get
            # a small fixed size.
            node_importance = 1.0 / rank if rank is not None else 0.1

        nodes.append(
            {
                "id": feature_name,
                "label": feature_name,
                "importance": node_importance,
                "total_interaction": total_interaction,
            }
        )

    # Build edges (only upper triangle to avoid duplicates)
    edges = []
    for i in range(n_features):
        for j in range(i + 1, n_features):
            interaction_value = float(interaction_matrix[i, j])
            # abs(nan) > 0 is False, so NaN entries are skipped as well.
            if abs(interaction_value) > 0:
                edges.append(
                    {
                        "source": feature_names[i],
                        "target": feature_names[j],
                        "weight": interaction_value,
                        "abs_weight": abs(interaction_value),
                    }
                )

    # Sort edges by absolute weight (stable, so ties keep row-major order)
    edges.sort(key=lambda e: cast(float, e["abs_weight"]), reverse=True)

    # Perform simple clustering based on strong interactions
    clusters = _detect_interaction_clusters(interaction_matrix, feature_names, cluster_threshold)

    return NetworkGraphData(nodes=nodes, edges=edges, clusters=clusters)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _build_interaction_matrix_data(
    interaction_matrix: np.ndarray, feature_names: list[str]
) -> InteractionMatrixData:
    """Package the interaction matrix for heatmap rendering.

    The matrix is converted to nested Python lists so it can be serialised
    to JSON. ``max_interaction`` and ``mean_interaction`` are computed over
    the absolute values of the strict upper triangle (diagonal excluded).
    """
    upper = np.triu_indices(len(feature_names), k=1)
    off_diagonal_abs = np.abs(interaction_matrix[upper])

    return InteractionMatrixData(
        features=feature_names,
        matrix=interaction_matrix.tolist(),
        max_interaction=float(np.max(off_diagonal_abs)),
        mean_interaction=float(np.mean(off_diagonal_abs)),
    )
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _build_strength_distribution(interaction_matrix: np.ndarray) -> dict[str, Any]:
    """Describe the distribution of absolute pairwise interaction strengths.

    Only the strict upper triangle is used, so each unordered feature pair
    contributes once and the diagonal is ignored.

    Returns
    -------
    dict
        ``mean``, ``median``, ``std``, ``min`` and ``max`` of the absolute
        values; ``percentiles`` mapping ``"p10"`` ... ``"p99"`` to floats;
        and a 20-bin ``histogram`` with ``counts`` and ``bin_edges`` lists.
    """
    size = interaction_matrix.shape[0]
    strengths = np.abs(interaction_matrix[np.triu_indices(size, k=1)])

    percentile_table: dict[str, float] = {}
    for level in (10, 25, 50, 75, 90, 95, 99):
        percentile_table[f"p{level}"] = float(np.percentile(strengths, level))

    counts, edges = np.histogram(strengths, bins=20)

    return {
        "mean": float(np.mean(strengths)),
        "median": float(np.median(strengths)),
        "std": float(np.std(strengths)),
        "min": float(np.min(strengths)),
        "max": float(np.max(strengths)),
        "percentiles": percentile_table,
        "histogram": {"counts": counts.tolist(), "bin_edges": edges.tolist()},
    }
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def _detect_interaction_clusters(
    interaction_matrix: np.ndarray, feature_names: list[str], threshold: float
) -> list[list[str]]:
    """Detect clusters of strongly interacting features using simple thresholding.

    Feature ``i`` links to feature ``j`` when ``abs(interaction_matrix[i, j])``
    is strictly greater than ``threshold``; the diagonal is ignored. Starting
    from each not-yet-visited feature in index order, a depth-first search
    follows these links. For a symmetric matrix this yields the connected
    components of the interaction graph. Groups with a single feature are
    dropped. Members are listed in depth-first visiting order, exploring
    lower-index neighbours first.

    The traversal uses an explicit stack instead of recursion, so long chains
    of linked features (more than ~1000) cannot hit Python's recursion limit.

    Parameters
    ----------
    interaction_matrix : np.ndarray
        Interaction values; rows/columns follow ``feature_names``.
    feature_names : list of str
        Feature labels.
    threshold : float
        Exclusive lower bound on absolute interaction for a link.

    Returns
    -------
    list of list of str
        Feature names per cluster, ordered by each cluster's starting feature.
    """
    n_features = len(feature_names)

    # Create adjacency matrix based on threshold (NaN compares False)
    adj_matrix = np.abs(interaction_matrix) > threshold
    np.fill_diagonal(adj_matrix, False)  # No self-loops

    visited = [False] * n_features
    clusters: list[list[str]] = []

    for seed in range(n_features):
        if visited[seed]:
            continue

        cluster_indices: list[int] = []
        stack = [seed]
        while stack:
            node = stack.pop()
            if visited[node]:
                # Reached through an earlier branch after being pushed.
                continue
            visited[node] = True
            cluster_indices.append(node)

            neighbors = np.flatnonzero(adj_matrix[node, :n_features]).tolist()
            # Push in reverse so the lowest-index neighbour is popped first,
            # reproducing the visiting order of a recursive DFS.
            for neighbor in reversed(neighbors):
                if not visited[neighbor]:
                    stack.append(neighbor)

        if len(cluster_indices) > 1:  # Only include clusters with >1 feature
            clusters.append([feature_names[idx] for idx in cluster_indices])

    return clusters
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _generate_interaction_interpretation(
    feature_name: str, top_partners: list[tuple[str, float]]
) -> str:
    """Write a short narrative about a feature's strongest interaction partners.

    At most the first three entries of ``top_partners`` are mentioned, each
    rendered as ``'name' (value)`` with three decimals. An empty list yields
    a "no significant interactions" sentence instead.
    """
    if top_partners:
        mentions = ", ".join(
            f"'{entry[0]}' ({entry[1]:.3f})" for entry in top_partners[:3]
        )
        return (
            f"'{feature_name}' shows strongest interactions with {mentions}. "
            "These interaction effects suggest the feature's predictive power "
            "depends on the values of these partner features."
        )
    return f"'{feature_name}' has no significant interactions."
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def _generate_interaction_llm_context(
    summary: dict[str, Any],
    _per_feature: dict[str, FeatureInteractionData],
    network_graph: NetworkGraphData,
    strength_distribution: dict[str, Any],
) -> LLMContextData:
    """Generate auto-narratives for interaction analysis.

    Parameters
    ----------
    summary : dict[str, Any]
        High-level metrics. Keys read: ``n_features``, ``n_interactions``,
        ``strongest_pair`` (indexed as ``[0]`` / ``[1]``),
        ``strongest_interaction``, ``most_interactive_feature`` and
        ``max_total_interaction``.
    _per_feature : dict[str, FeatureInteractionData]
        Per-feature interaction details. Not read by this function.
    network_graph : NetworkGraphData
        Graph data. Only its ``clusters`` list is read.
    strength_distribution : dict[str, Any]
        Distribution statistics. ``mean`` and ``median`` are read.

    Returns
    -------
    LLMContextData
        Summary narrative, key insights, recommendations, caveats and an
        ``analysis_quality`` label of 'high', 'medium' or 'low'.
    """
    n_features = summary["n_features"]
    n_interactions = summary["n_interactions"]
    strongest_pair = summary["strongest_pair"]
    strongest_value = summary["strongest_interaction"]
    most_interactive = summary["most_interactive_feature"]
    max_total_interaction = summary["max_total_interaction"]
    clusters = network_graph["clusters"]

    # NOTE(review): the producer of ``summary`` is not visible here. If it can
    # report no strongest pair (e.g. zero interactions), indexing/formatting
    # None used to raise TypeError; those sentences are now skipped instead.
    has_strongest = strongest_pair is not None and strongest_value is not None

    # Build summary narrative
    summary_narrative = (
        f"This interaction analysis examined {n_features} features, identifying "
        f"{n_interactions} pairwise interactions. "
    )
    if has_strongest:
        summary_narrative += (
            f"The strongest interaction ({strongest_value:.3f}) occurs between "
            f"'{strongest_pair[0]}' and '{strongest_pair[1]}'. "
        )
    if clusters:
        summary_narrative += (
            f"Cluster analysis identified {len(clusters)} group(s) of strongly interacting features. "
        )

    # Key insights
    key_insights: list[str] = []

    # Insight 1: Strongest interaction
    if has_strongest:
        key_insights.append(
            f"Strongest interaction: {strongest_pair[0]} <-> {strongest_pair[1]} (strength: {strongest_value:.3f})"
        )

    # Insight 2: Most interactive feature
    key_insights.append(
        f"Most interactive feature: '{most_interactive}' (total interaction: {max_total_interaction:.3f})"
    )

    # Insight 3: a mean well above the median indicates a long right tail,
    # i.e. a handful of large interactions pull the average up.
    mean_strength = strength_distribution["mean"]
    median_strength = strength_distribution["median"]
    if mean_strength > median_strength * 1.5:
        key_insights.append(
            "Interaction strength distribution is right-skewed "
            f"(mean: {mean_strength:.3f}, median: {median_strength:.3f}) - "
            "a few strong interactions dominate"
        )

    # Insight 4: Clustering (list only the first five members of the largest group)
    if clusters:
        largest_cluster = list(max(clusters, key=len))  # type: ignore[arg-type]
        preview = ", ".join(largest_cluster[:5])
        ellipsis = "..." if len(largest_cluster) > 5 else ""
        key_insights.append(
            f"Largest interaction cluster has {len(largest_cluster)} features: {preview}{ellipsis}"
        )

    # Recommendations
    recommendations: list[str] = []

    # Rec 1: Focus on strong interactions
    if has_strongest:
        recommendations.append(
            f"Investigate the {strongest_pair[0]}/{strongest_pair[1]} interaction further. "
            "Strong interactions suggest conditional effects or non-linear relationships."
        )

    # Rec 2: Feature engineering
    if clusters:
        recommendations.append(
            "Consider creating interaction features (products, ratios) for clustered "
            "feature groups to capture non-linear effects explicitly."
        )

    # Rec 3: Model selection
    recommendations.append(
        "Tree-based models and neural networks can capture these interactions naturally. "
        "Linear models may benefit from explicit interaction terms."
    )

    # Caveats
    caveats = [
        "SHAP interactions measure feature contribution interactions, not statistical "
        "correlations. High interaction doesn't imply high correlation.",
        "Interaction values are model-specific and depend on the underlying model structure.",
    ]

    # Coarse quality label: enough features plus a non-trivial total interaction.
    if n_features >= 5 and max_total_interaction > 0.1:
        analysis_quality = "high"
    elif n_features >= 3:
        analysis_quality = "medium"
    else:
        analysis_quality = "low"

    return LLMContextData(
        summary_narrative=summary_narrative,
        key_insights=key_insights,
        recommendations=recommendations,
        caveats=caveats,
        analysis_quality=analysis_quality,
    )
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Type definitions for data extraction.
|
|
2
|
+
|
|
3
|
+
TypedDict classes for structured visualization data packages.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Any, TypedDict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MethodImportanceData(TypedDict, total=False):
    """Importance results produced by a single importance method.

    Every key is optional (``total=False``).

    Attributes
    ----------
    importances
        Feature name -> importance score.
    ranking
        Feature names sorted by importance.
    std
        Per-feature standard deviation, when the method reports one (PFI).
    confidence_intervals
        Per-feature 95% confidence interval ``(low, high)``, when available.
    raw_values
        Per-repeat importance values, when available (PFI).
    metadata
        Method-specific metadata.
    """

    importances: dict[str, float]
    ranking: list[str]
    std: dict[str, float] | None
    confidence_intervals: dict[str, tuple[float, float]] | None
    raw_values: list[dict[str, float]] | None
    metadata: dict[str, Any]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FeatureDetailData(TypedDict, total=True):
    """Aggregated view of a single feature across every importance analysis.

    All keys are required.

    Attributes
    ----------
    consensus_rank
        Overall ranking of the feature.
    consensus_score
        Consensus importance score.
    method_ranks
        Method name -> rank of this feature under that method.
    method_scores
        Method name -> importance score under that method.
    method_stds
        Method name -> standard deviation (for methods that provide one).
    agreement_level
        Cross-method agreement label: 'high', 'medium' or 'low'.
    stability_score
        Stability in the range 0-1; higher means more stable.
    interpretation
        Auto-generated interpretation text.
    """

    consensus_rank: int
    consensus_score: float
    method_ranks: dict[str, int]
    method_scores: dict[str, float]
    method_stds: dict[str, float]
    agreement_level: str
    stability_score: float
    interpretation: str
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MethodComparisonData(TypedDict, total=True):
    """Cross-method agreement and comparison metrics.

    Attributes
    ----------
    correlation_matrix
        Method x method correlation matrix.
    correlation_methods
        Method names labelling both matrix axes.
    rank_differences
        ``(method1, method2)`` -> {feature: rank difference}.
    agreement_summary
        Pairwise correlations flattened into a dict.
    """

    correlation_matrix: list[list[float]]
    correlation_methods: list[str]
    rank_differences: dict[tuple[str, str], dict[str, int]]
    agreement_summary: dict[str, float]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class UncertaintyData(TypedDict, total=True):
    """Uncertainty and stability metrics for importance estimates.

    Attributes
    ----------
    method_stability
        Method -> stability score in the range 0-1.
    rank_stability
        Feature -> list of its ranks across bootstrap samples.
    confidence_intervals
        Method -> {feature: (low, high)}.
    coefficient_of_variation
        Method -> {feature: coefficient of variation}.
    """

    method_stability: dict[str, float]
    rank_stability: dict[str, list[int]]
    confidence_intervals: dict[str, dict[str, tuple[float, float]]]
    coefficient_of_variation: dict[str, dict[str, float]]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class LLMContextData(TypedDict, total=True):
    """Structured, natural-language material intended for LLM interpretation.

    Attributes
    ----------
    summary_narrative
        High-level summary in natural language.
    key_insights
        Bullet-point findings.
    recommendations
        Actionable recommendations.
    caveats
        Limitations and warnings.
    analysis_quality
        Quality label: 'high', 'medium' or 'low'.
    """

    summary_narrative: str
    key_insights: list[str]
    recommendations: list[str]
    caveats: list[str]
    analysis_quality: str
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ImportanceVizData(TypedDict, total=True):
    """Complete visualization data package for importance analysis.

    Attributes
    ----------
    summary
        High-level metrics.
    per_method
        Method name -> detailed per-method importance data.
    per_feature
        Feature name -> aggregated cross-method view.
    uncertainty
        Stability and confidence metrics.
    method_comparison
        Cross-method agreement analysis.
    metadata
        Context information.
    llm_context
        LLM-friendly narratives.
    """

    summary: dict[str, Any]
    per_method: dict[str, MethodImportanceData]
    per_feature: dict[str, FeatureDetailData]
    uncertainty: UncertaintyData
    method_comparison: MethodComparisonData
    metadata: dict[str, Any]
    llm_context: LLMContextData
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class FeatureInteractionData(TypedDict, total=True):
    """Interaction details for a single feature.

    Attributes
    ----------
    feature_name
        Name of the feature.
    top_partners
        ``(partner_feature, interaction_strength)`` pairs.
    total_interaction_strength
        Sum of absolute interactions.
    cluster_id
        ID of the interaction cluster, or None (e.g. when no clustering ran).
    interpretation
        Auto-generated interpretation text.
    """

    feature_name: str
    top_partners: list[tuple[str, float]]
    total_interaction_strength: float
    cluster_id: int | None
    interpretation: str
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class NetworkGraphData(TypedDict, total=True):
    """Network-graph representation of feature interactions.

    Attributes
    ----------
    nodes
        Node dicts, e.g. ``{id: str, label: str, importance: float, ...}``.
    edges
        Edge dicts, e.g. ``{source: str, target: str, weight: float, ...}``.
    clusters
        Groups of feature names formed from the interactions.
    """

    nodes: list[dict[str, Any]]
    edges: list[dict[str, Any]]
    clusters: list[list[str]]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class InteractionMatrixData(TypedDict, total=True):
    """Matrix representation of pairwise interactions.

    Attributes
    ----------
    features
        Feature names in matrix-axis order.
    matrix
        Symmetric interaction matrix.
    max_interaction
        Maximum interaction value.
    mean_interaction
        Mean interaction strength.
    """

    features: list[str]
    matrix: list[list[float]]
    max_interaction: float
    mean_interaction: float
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class InteractionVizData(TypedDict, total=True):
    """Complete visualization data package for interaction analysis.

    Attributes
    ----------
    summary
        High-level metrics.
    per_feature
        Feature name -> interaction details.
    network_graph
        Graph visualization data.
    interaction_matrix
        Matrix visualization data.
    strength_distribution
        Distribution of interaction strengths.
    metadata
        Context information.
    llm_context
        LLM-friendly narratives.
    """

    summary: dict[str, Any]
    per_feature: dict[str, FeatureInteractionData]
    network_graph: NetworkGraphData
    interaction_matrix: InteractionMatrixData
    strength_distribution: dict[str, Any]
    metadata: dict[str, Any]
    llm_context: LLMContextData
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""Validation helpers for data extraction.
|
|
2
|
+
|
|
3
|
+
Provides length and dimension validation for extracted visualization data.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _validate_lengths_match(
    *arrays: tuple[str, list | np.ndarray],
) -> None:
    """Ensure every named array has the same length.

    Parameters
    ----------
    *arrays : tuple[str, list | np.ndarray]
        ``(name, array)`` pairs. Passing nothing is a no-op.

    Raises
    ------
    ValueError
        If at least two arrays differ in length; the message lists every
        ``name=length`` pair.
    """
    named_lengths = [(label, len(values)) for label, values in arrays]
    if not named_lengths:
        return

    reference = named_lengths[0][1]
    if all(size == reference for _, size in named_lengths):
        return

    details = ", ".join(f"{label}={size}" for label, size in named_lengths)
    raise ValueError(
        f"Length mismatch in data extraction: {details}. "
        "All arrays must have the same length for consistent visualization."
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _validate_matrix_feature_alignment(matrix: np.ndarray, feature_names: list[str]) -> None:
    """Validate that interaction matrix dimensions match feature names.

    Parameters
    ----------
    matrix : np.ndarray
        Square interaction matrix. Array-likes such as nested lists are also
        accepted; they are converted with ``np.asarray`` before checking.
    feature_names : list[str]
        Feature names for matrix axes.

    Raises
    ------
    ValueError
        If matrix is not 2D, not square, or its size doesn't match the
        feature count.
    """
    # np.asarray is a no-op for ndarrays; for nested lists it allows the
    # checks below instead of an AttributeError on ``.ndim``/``.shape``.
    arr = np.asarray(matrix)
    n_features = len(feature_names)
    if arr.ndim != 2:
        raise ValueError(
            f"Interaction matrix must be 2D, got {arr.ndim}D with shape {arr.shape}"
        )
    if arr.shape[0] != arr.shape[1]:
        raise ValueError(f"Interaction matrix must be square, got shape {arr.shape}")
    if arr.shape[0] != n_features:
        raise ValueError(
            f"Interaction matrix size ({arr.shape[0]}) does not match "
            f"number of features ({n_features})"
        )
|