ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,772 @@
|
|
|
1
|
+
"""Feature interaction detection: H-statistic, SHAP interactions, and comprehensive analysis.
|
|
2
|
+
|
|
3
|
+
This module provides methods for detecting and analyzing feature interactions
|
|
4
|
+
including Friedman's H-statistic and SHAP interaction values.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Union, cast
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import polars as pl
|
|
13
|
+
from scipy.stats import spearmanr
|
|
14
|
+
|
|
15
|
+
from ml4t.diagnostic.evaluation.metrics.conditional_ic import compute_conditional_ic
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from numpy.typing import NDArray
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def compute_h_statistic(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    feature_pairs: list[tuple[int, int]] | list[tuple[str, str]] | None = None,
    feature_names: list[str] | None = None,
    n_samples: int = 100,
    grid_resolution: int = 20,
) -> dict[str, Any]:
    """Compute Friedman's H-statistic for feature interaction strength.

    The H-statistic (Friedman & Popescu 2008) measures how much of the variation
    in predictions can be attributed to interactions between feature pairs, beyond
    their individual main effects.

    **Algorithm** (grid approximation):
        1. For each feature pair (j, k):
           - Build ``grid_resolution`` evenly spaced points over the observed
             [min, max] of x_j and of x_k (in the subsample)
           - Compute 2D partial dependence PD_{jk}(x_j, x_k) on the grid
           - Compute 1D partial dependences PD_j(x_j) and PD_k(x_k) on the same grids
           - Centre EACH partial dependence function on its own mean, then
             H^2 = sum[PD_{jk} - PD_j - PD_k]^2 / sum[PD_{jk}^2]
           - H is 0 when the pair contributes additively and grows with
             interaction strength

    Each partial dependence value is the mean of ``model.predict`` over the
    (sub)sampled rows with the grid value(s) substituted in. 1D partial
    dependences are computed once per feature and reused across all pairs.
    Integer/boolean feature matrices are evaluated on a float64 copy so that
    non-integer grid values are not truncated.

    Parameters
    ----------
    model : Any
        Trained model with .predict() method
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    feature_pairs : list[tuple[int, int]] | list[tuple[str, str]] | None, default None
        List of (i, j) pairs to test. If None, tests all pairs.
    feature_names : list[str] | None, default None
        Feature names. If None, uses column names or f0, f1, ...
    n_samples : int, default 100
        Number of samples to use for PD computation (subsample if needed)
    grid_resolution : int, default 20
        Grid size for PD evaluation

    Returns
    -------
    dict[str, Any]
        Dictionary with:
        - h_statistics: List of (feature_i, feature_j, H_value) sorted by H descending
        - feature_names: List of feature names used
        - n_features: Number of features
        - n_pairs_tested: Number of pairs tested
        - n_samples_used: Number of rows used for PD computation
        - grid_resolution: Grid size used
        - computation_time: Time in seconds

    Raises
    ------
    KeyError
        If ``feature_pairs`` contains a feature name not present in ``feature_names``.

    References
    ----------
    - Friedman, J. H., & Popescu, B. E. (2008). Predictive learning via rule ensembles.
      The Annals of Applied Statistics, 2(3), 916-954.

    Examples
    --------
    >>> import lightgbm as lgb
    >>> model = lgb.LGBMRegressor()
    >>> model.fit(X_train, y_train)
    >>> results = compute_h_statistic(model, X_test)
    >>> for feat_i, feat_j, h_val in results["h_statistics"][:5]:
    ...     print(f" {feat_i} x {feat_j}: H = {h_val:.4f}")
    """
    start_time = time.time()

    # Convert input to numpy, taking feature names from the frame when available
    if isinstance(X, pl.DataFrame):
        if feature_names is None:
            feature_names = X.columns
        X_array = X.to_numpy()
    elif isinstance(X, pd.DataFrame):
        if feature_names is None:
            feature_names = list(X.columns)
        X_array = X.values
    else:  # numpy array (or array-like)
        X_array = np.asarray(X)
        if feature_names is None:
            feature_names = [f"f{i}" for i in range(X_array.shape[1])]

    # Grid points are generally non-integer; writing them into an integer/bool
    # array would silently truncate them, so evaluate on a float copy instead.
    if X_array.dtype.kind in "iub":
        X_array = X_array.astype(np.float64)

    n_total_samples, n_features = X_array.shape

    # Subsample rows (fixed seed for reproducibility) to bound the predict() cost
    if n_total_samples > n_samples:
        rng = np.random.RandomState(42)
        indices = rng.choice(n_total_samples, size=n_samples, replace=False)
        X_sample = X_array[indices]
    else:
        X_sample = X_array
        n_samples = n_total_samples

    # Ensure feature_names is a list for indexing
    feature_names_list: list[str] = list(feature_names)

    # Normalise feature_pairs to integer column indices
    pairs_int: list[tuple[int, int]]
    if feature_pairs is None:
        # Test all unordered pairs
        pairs_int = [(i, j) for i in range(n_features) for j in range(i + 1, n_features)]
    elif len(feature_pairs) > 0 and isinstance(feature_pairs[0][0], str):
        # Convert string pairs to indices, reporting every unknown name at once
        name_to_idx = {name: idx for idx, name in enumerate(feature_names_list)}
        unknown = sorted({str(n) for pair in feature_pairs for n in pair} - set(name_to_idx))
        if unknown:
            raise KeyError(f"Unknown feature name(s) in feature_pairs: {unknown}")
        pairs_int = [(name_to_idx[str(i)], name_to_idx[str(j)]) for i, j in feature_pairs]
    else:
        # Already integer pairs
        pairs_int = [(int(i), int(j)) for i, j in feature_pairs]

    # Grids and 1D partial dependences depend only on a single feature, so they
    # are computed once and shared by every pair containing that feature.
    grid_cache: dict[int, "NDArray[Any]"] = {}
    pd_1d_cache: dict[int, "NDArray[Any]"] = {}

    def _grid(feat: int) -> "NDArray[Any]":
        """Evenly spaced evaluation grid over the sampled range of one feature."""
        if feat not in grid_cache:
            column = X_sample[:, feat]
            grid_cache[feat] = np.linspace(
                float(column.min()), float(column.max()), grid_resolution
            )
        return grid_cache[feat]

    def _pd_1d(feat: int) -> "NDArray[Any]":
        """1D partial dependence of the model on one feature over its grid."""
        if feat not in pd_1d_cache:
            values = np.zeros(grid_resolution)
            X_temp = X_sample.copy()  # only column `feat` is overwritten below
            for g, grid_val in enumerate(_grid(feat)):
                X_temp[:, feat] = grid_val
                values[g] = float(np.mean(model.predict(X_temp)))
            pd_1d_cache[feat] = values
        return pd_1d_cache[feat]

    h_results: list[tuple[str, str, float]] = []

    for feat_i, feat_j in pairs_int:
        x_i_grid = _grid(feat_i)
        x_j_grid = _grid(feat_j)

        # 2D partial dependence PD_{ij}: average prediction with both features
        # fixed at each grid combination (other columns keep their sampled values)
        pd_2d = np.zeros((grid_resolution, grid_resolution))
        X_temp = X_sample.copy()
        for gi, x_i_val in enumerate(x_i_grid):
            X_temp[:, feat_i] = x_i_val
            for gj, x_j_val in enumerate(x_j_grid):
                X_temp[:, feat_j] = x_j_val
                pd_2d[gi, gj] = float(np.mean(model.predict(X_temp)))

        pd_i = _pd_1d(feat_i)
        pd_j = _pd_1d(feat_j)

        # Friedman & Popescu centre each PD function on its OWN mean; using one
        # shared constant leaves a spurious offset that inflates H even for a
        # purely additive pair.
        pd_2d_centered = pd_2d - pd_2d.mean()
        pd_i_centered = pd_i - pd_i.mean()
        pd_j_centered = pd_j - pd_j.mean()

        # Interaction component: broadcast 1D PDs along rows (i) and columns (j)
        interaction = (
            pd_2d_centered
            - pd_i_centered[:, np.newaxis]  # Shape: (grid_resolution, 1)
            - pd_j_centered[np.newaxis, :]  # Shape: (1, grid_resolution)
        )

        numerator = float(np.sum(interaction**2))
        denominator = float(np.sum(pd_2d_centered**2))

        if denominator > 1e-10:  # Avoid division by zero (flat joint PD)
            h_stat = float(np.sqrt(max(0.0, numerator / denominator)))
        else:
            h_stat = 0.0

        h_results.append((feature_names_list[feat_i], feature_names_list[feat_j], h_stat))

    # Sort by H-statistic descending
    h_results.sort(key=lambda x: x[2], reverse=True)

    computation_time = time.time() - start_time

    return {
        "h_statistics": h_results,
        "feature_names": feature_names,
        "n_features": n_features,
        "n_pairs_tested": len(h_results),
        "n_samples_used": n_samples,
        "grid_resolution": grid_resolution,
        "computation_time": computation_time,
    }
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def compute_shap_interactions(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    feature_names: list[str] | None = None,
    _check_additivity: bool = False,
    max_samples: int | None = None,
    top_k: int | None = None,
) -> dict[str, Any]:
    """Compute SHAP interaction values for feature pairs.

    SHAP interaction values decompose the SHAP value of each feature into:
    - Main effect (the feature's individual contribution)
    - Interaction effects (how the feature's impact changes with other features)

    Parameters
    ----------
    model : Any
        Trained tree-based model
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features)
    feature_names : list[str] | None, default None
        Feature names. If None, uses column names or f0, f1, ...
    _check_additivity : bool, default False
        Internal parameter (not used for interaction values)
    max_samples : int | None, default None
        Maximum samples to use (subsample with a fixed seed if larger). Must be >= 1.
    top_k : int | None, default None
        Return only top K interactions by absolute magnitude. Must be >= 0.

    Returns
    -------
    dict[str, Any]
        Dictionary with:
        - interaction_matrix: (n_features, n_features) mean absolute interactions
        - feature_names: List of feature names
        - top_interactions: List of (feature_i, feature_j, mean_interaction) sorted by magnitude
        - n_features: Number of features
        - n_samples_used: Number of samples used
        - computation_time: Time in seconds

    Raises
    ------
    ImportError
        If the ``shap`` package is not installed.
    ValueError
        If ``max_samples`` < 1, ``top_k`` < 0, or the reduced interaction
        values are not a 2D (n_features, n_features) matrix.

    Notes
    -----
    - Requires shap package (install with: pip install ml4t-diagnostic[ml])
    - Only works with tree-based models (uses TreeExplainer)
    - Interaction matrix is symmetric: interaction(i,j) = interaction(j,i)
    - Multi-output models: binary classification uses the positive class (index 1);
      with more than two classes, absolute values are averaged across classes in
      both the list and 4D output formats, so opposite-signed per-class
      contributions do not cancel.
    """
    start_time = time.time()

    # Validate arguments up front: a negative top_k would silently drop the tail
    # via slicing, and max_samples < 1 would produce an empty sample.
    if max_samples is not None and max_samples < 1:
        raise ValueError(f"max_samples must be a positive integer or None, got {max_samples}")
    if top_k is not None and top_k < 0:
        raise ValueError(f"top_k must be a non-negative integer or None, got {top_k}")

    # Check shap availability
    try:
        import shap
    except ImportError as e:
        raise ImportError(
            "SHAP is required for interaction values. "
            "Install with: pip install ml4t-diagnostic[ml] "
            "or: pip install shap>=0.43.0"
        ) from e

    # Convert input to numpy and extract feature names
    if isinstance(X, pl.DataFrame):
        if feature_names is None:
            feature_names = X.columns
        X_array = X.to_numpy()
    elif isinstance(X, pd.DataFrame):
        if feature_names is None:
            feature_names = list(X.columns)
        X_array = X.values
    else:  # numpy array (or array-like)
        X_array = np.asarray(X)
        if feature_names is None:
            feature_names = [f"f{i}" for i in range(X_array.shape[1])]

    # feature_names is always set by one of the branches above
    feature_names_list: list[str] = list(feature_names)

    n_total_samples, n_features = X_array.shape

    # Subsample if needed (fixed seed for reproducibility)
    if max_samples is not None and n_total_samples > max_samples:
        rng = np.random.RandomState(42)
        indices = rng.choice(n_total_samples, size=max_samples, replace=False)
        X_sample = X_array[indices]
        n_samples_used = max_samples
    else:
        X_sample = X_array
        n_samples_used = n_total_samples

    # Compute SHAP interaction values using TreeExplainer
    explainer = shap.TreeExplainer(model)
    shap_interaction_values = explainer.shap_interaction_values(X_sample)

    # Handle multi-output models (classification), list format: one array per class
    if isinstance(shap_interaction_values, list):
        if len(shap_interaction_values) == 2:
            # Binary: use the positive class
            shap_interaction_values = shap_interaction_values[1]
        else:
            # Multiclass: average magnitudes across classes, matching the 4D
            # branch below; averaging signed values lets classes cancel out.
            shap_interaction_values = np.mean(
                np.abs(np.asarray(shap_interaction_values)), axis=0
            )

    shap_interaction_values = np.asarray(shap_interaction_values)

    # 4D array format: (n_samples, n_features, n_features, n_classes)
    if shap_interaction_values.ndim == 4:
        if shap_interaction_values.shape[-1] == 2:
            # Binary classification: use positive class (index 1)
            shap_interaction_values = shap_interaction_values[:, :, :, 1]
        else:
            # Multiclass: average absolute values across classes
            shap_interaction_values = np.mean(np.abs(shap_interaction_values), axis=-1)

    # Shape should now be: (n_samples, n_features, n_features)

    # Compute mean absolute interaction matrix
    interaction_matrix = np.mean(np.abs(shap_interaction_values), axis=0)

    # Ensure 2D matrix (n_features, n_features)
    if interaction_matrix.ndim != 2:
        raise ValueError(
            f"Interaction matrix should be 2D but got shape {interaction_matrix.shape}. "
            f"Raw SHAP values shape: {shap_interaction_values.shape}"
        )

    # Off-diagonal upper triangle only: the diagonal holds main effects and the
    # lower triangle duplicates the upper one.
    interactions_list = [
        (feature_names_list[i], feature_names_list[j], float(interaction_matrix[i, j]))
        for i in range(n_features)
        for j in range(i + 1, n_features)
    ]

    # Sort by absolute interaction strength descending
    interactions_list.sort(key=lambda x: abs(x[2]), reverse=True)

    # Limit to top K if requested
    if top_k is not None:
        interactions_list = interactions_list[:top_k]

    computation_time = time.time() - start_time

    return {
        "interaction_matrix": interaction_matrix,
        "feature_names": feature_names,
        "top_interactions": interactions_list,
        "n_features": n_features,
        "n_samples_used": n_samples_used,
        "computation_time": computation_time,
    }
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def _generate_interaction_interpretation(
|
|
350
|
+
top_interactions: list[tuple[str, str]],
|
|
351
|
+
method_agreement: dict[tuple[str, str], float],
|
|
352
|
+
warnings: list[str],
|
|
353
|
+
n_consensus: int,
|
|
354
|
+
) -> str:
|
|
355
|
+
"""Generate human-readable interpretation of interaction analysis.
|
|
356
|
+
|
|
357
|
+
Parameters
|
|
358
|
+
----------
|
|
359
|
+
top_interactions : list[tuple[str, str]]
|
|
360
|
+
Top feature pairs from consensus ranking
|
|
361
|
+
method_agreement : dict[tuple[str, str], float]
|
|
362
|
+
Pairwise correlations between method rankings
|
|
363
|
+
warnings : list[str]
|
|
364
|
+
List of potential issues detected
|
|
365
|
+
n_consensus : int
|
|
366
|
+
Number of interactions in top 10 across all methods
|
|
367
|
+
|
|
368
|
+
Returns
|
|
369
|
+
-------
|
|
370
|
+
str
|
|
371
|
+
Human-readable interpretation summary
|
|
372
|
+
"""
|
|
373
|
+
lines = []
|
|
374
|
+
|
|
375
|
+
# Consensus interactions
|
|
376
|
+
if n_consensus > 0:
|
|
377
|
+
lines.append(
|
|
378
|
+
f"Strong consensus: {n_consensus} interactions rank in top 10 across all methods"
|
|
379
|
+
)
|
|
380
|
+
pairs_str = ", ".join([f"({a}, {b})" for a, b in top_interactions[:3]])
|
|
381
|
+
lines.append(f" Top consensus interactions: {pairs_str}")
|
|
382
|
+
else:
|
|
383
|
+
lines.append("Weak consensus: Different methods identify different important interactions")
|
|
384
|
+
|
|
385
|
+
# Method agreement
|
|
386
|
+
if method_agreement:
|
|
387
|
+
avg_agreement = float(np.mean(list(method_agreement.values())))
|
|
388
|
+
if avg_agreement > 0.7:
|
|
389
|
+
lines.append(f"High agreement between methods (avg correlation: {avg_agreement:.2f})")
|
|
390
|
+
elif avg_agreement > 0.5:
|
|
391
|
+
lines.append(
|
|
392
|
+
f"Moderate agreement between methods (avg correlation: {avg_agreement:.2f})"
|
|
393
|
+
)
|
|
394
|
+
else:
|
|
395
|
+
lines.append(
|
|
396
|
+
f"Low agreement between methods (avg correlation: {avg_agreement:.2f}) - investigate further"
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
# Warnings
|
|
400
|
+
if warnings:
|
|
401
|
+
lines.append("\nPotential Issues:")
|
|
402
|
+
for warning in warnings:
|
|
403
|
+
lines.append(f" - {warning}")
|
|
404
|
+
|
|
405
|
+
return "\n".join(lines)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def _canonical_pair(feat_a: Any, feat_b: Any) -> tuple[str, str]:
    """Return an order-independent, string-typed key for a feature pair.

    Methods may report a pair as ``(a, b)`` or ``(b, a)``, and feature names may
    be non-strings (e.g. integer DataFrame columns). Both names are therefore
    stringified and put in sorted order so every comparison uses the same key.
    """
    name_a, name_b = str(feat_a), str(feat_b)
    return (name_a, name_b) if name_a <= name_b else (name_b, name_a)


def _method_interaction_list(result: Any) -> list[tuple[str, str, float]] | None:
    """Return the ranked ``(feat_a, feat_b, score)`` entries produced by one method.

    Conditional IC and SHAP results store them under ``"top_interactions"``;
    H-statistic results store them under ``"h_statistics"``. Returns ``None``
    when neither key is present so the caller can leave the method out.
    """
    if "top_interactions" in result:
        return cast(list[tuple[str, str, float]], result["top_interactions"])
    if "h_statistics" in result:
        return cast(list[tuple[str, str, float]], result["h_statistics"])
    return None


def _interaction_magnitude(score: Any) -> float:
    """Sort key for interaction scores: ``abs(score)``, with None/NaN treated as 0.0."""
    if score is None:
        return 0.0
    magnitude = abs(float(score))
    # NaN compares False against everything and would scramble the sort order.
    return 0.0 if np.isnan(magnitude) else magnitude


def _rank_agreement(ranks_a: "NDArray[Any]", ranks_b: "NDArray[Any]") -> float:
    """Spearman correlation between two rank vectors, or NaN when undefined.

    The correlation is undefined for fewer than two pairs or when either vector
    is constant. Those cases return NaN directly instead of passing degenerate
    input to ``spearmanr``.
    """
    if ranks_a.size < 2 or np.all(ranks_a == ranks_a[0]) or np.all(ranks_b == ranks_b[0]):
        return float("nan")
    corr, _ = spearmanr(ranks_a, ranks_b)
    return float(corr)


def analyze_interactions(
    model: Any,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"],
    feature_pairs: list[tuple[str, str]] | None = None,
    methods: list[str] | None = None,
    n_quantiles: int = 5,
    grid_resolution: int = 20,
    max_samples: int = 200,
) -> dict[str, Any]:
    """Comprehensive feature interaction analysis comparing multiple methods.

    **This is a TEAR SHEET function** - it runs multiple interaction detection methods
    and generates a comparison report with consensus ranking and interpretation.

    **Use Case**: "Which feature pairs interact in my model? Do different methods agree?"

    This function replaces 100+ lines of manual comparison code by providing
    integrated analysis showing:
    - Individual method results (Conditional IC, H-statistic, SHAP interactions)
    - Consensus ranking (interactions important across methods)
    - Method agreement/disagreement analysis
    - Auto-generated insights and warnings

    Parameters
    ----------
    model : Any
        Fitted model. Requirements vary by method:
        - Conditional IC: Not used (analyzes feature correlations)
        - H-statistic: Must have `predict()` method
        - SHAP: Must be compatible with TreeExplainer
    X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
        Feature matrix (n_samples, n_features). For arrays, features are
        named ``f0``, ``f1``, ...
    y : Union[pl.Series, pd.Series, np.ndarray]
        Target values (n_samples,)
    feature_pairs : list[tuple[str, str]] | None, default None
        Specific feature pairs to analyze. If None, tests all pairs.
    methods : list[str] | None, default ["conditional_ic", "h_statistic", "shap"]
        Which methods to run. Unsupported names are reported in
        ``methods_failed`` instead of being silently ignored.
    n_quantiles : int, default 5
        Number of quantile bins for Conditional IC
    grid_resolution : int, default 20
        Grid size for partial dependence in H-statistic
    max_samples : int, default 200
        Maximum samples for SHAP and H-statistic

    Returns
    -------
    dict[str, Any]
        Comprehensive analysis results:
        - method_results: Dict of individual method outputs
        - consensus_ranking: ``(feat_a, feat_b, avg_rank, {method: score})``
          sorted by average rank across methods (lower is stronger)
        - method_agreement: Spearman correlations between method rankings
          (NaN when undefined, e.g. a single pair or a constant ranking)
        - top_interactions_consensus: Pairs in top 10 for ALL methods,
          ordered by consensus rank
        - warnings: Detected issues
        - interpretation: Auto-generated summary
        - methods_run: Methods successfully executed
        - methods_failed: Failed methods with error messages

    Raises
    ------
    ValueError
        If no methods are specified, a feature pair is malformed or
        references an unknown feature, there are no feature pairs to
        analyze, or all methods fail.
    """
    supported_methods = ("conditional_ic", "h_statistic", "shap")

    if methods is None:
        methods = list(supported_methods)

    if not methods:
        raise ValueError("At least one method must be specified")

    # Extract feature names (arrays get positional names f0, f1, ...)
    if isinstance(X, pl.DataFrame | pd.DataFrame):
        feature_names = list(X.columns)
    else:
        n_features = X.shape[1] if hasattr(X, "shape") else len(X[0])
        feature_names = [f"f{i}" for i in range(n_features)]

    # Determine feature pairs to analyze
    if feature_pairs is None:
        feature_pairs = [
            (feat_a, feat_b)
            for i, feat_a in enumerate(feature_names)
            for feat_b in feature_names[i + 1 :]
        ]
    else:
        feature_set = set(feature_names)
        for pair in feature_pairs:
            if len(pair) != 2:
                raise ValueError(f"Feature pair must have exactly 2 elements: {pair}")
            if pair[0] not in feature_set or pair[1] not in feature_set:
                raise ValueError(
                    f"Feature pair contains unknown features: {pair}. Available features: {feature_names}"
                )

    if not feature_pairs:
        raise ValueError(
            "No feature pairs to analyze: X needs at least two features "
            "or feature_pairs must be non-empty"
        )

    # Run each method; failures (incl. missing optional deps) are recorded, not raised
    results: dict[str, Any] = {}
    method_failures: list[tuple[str, str]] = [
        (name, f"Unknown method. Supported methods: {list(supported_methods)}")
        for name in methods
        if name not in supported_methods
    ]

    if "conditional_ic" in methods:
        try:
            # Conditional IC is a per-pair computation on the raw feature columns
            is_frame = isinstance(X, pl.DataFrame | pd.DataFrame)
            X_arr = None if is_frame else np.asarray(cast("NDArray[Any]", X))
            ic_results: list[tuple[str, str, float | None]] = []
            for feat_a, feat_b in feature_pairs:
                x_a: pl.Series | pd.Series | NDArray[Any]
                x_b: pl.Series | pd.Series | NDArray[Any]
                if X_arr is None:
                    x_a = X[feat_a]
                    x_b = X[feat_b]
                else:
                    x_a = X_arr[:, feature_names.index(feat_a)]
                    x_b = X_arr[:, feature_names.index(feat_b)]

                result = compute_conditional_ic(
                    feature_a=x_a,
                    feature_b=x_b,
                    forward_returns=y,
                    n_quantiles=n_quantiles,
                )
                # IC range across quantiles of feat_b is the interaction strength
                ic_results.append((feat_a, feat_b, result.get("ic_range", 0.0)))

            # Strongest (largest |IC range|) first; missing/NaN scores sort last
            ic_results.sort(key=lambda item: _interaction_magnitude(item[2]), reverse=True)

            results["conditional_ic"] = {
                "top_interactions": ic_results,
                "n_pairs_tested": len(ic_results),
            }
        except Exception as e:
            method_failures.append(("conditional_ic", str(e)))

    if "h_statistic" in methods:
        try:
            # compute_h_statistic expects column indices rather than names
            pair_indices = [
                (feature_names.index(feat_a), feature_names.index(feat_b))
                for feat_a, feat_b in feature_pairs
            ]
            results["h_statistic"] = compute_h_statistic(
                model,
                X,
                feature_pairs=pair_indices,
                feature_names=feature_names,
                n_samples=max_samples,
                grid_resolution=grid_resolution,
            )
        except Exception as e:
            method_failures.append(("h_statistic", str(e)))

    if "shap" in methods:
        try:
            shap_result = compute_shap_interactions(
                model,
                X,
                feature_names=feature_names,
                max_samples=max_samples,
            )

            # Keep only the requested pairs (matched regardless of order)
            requested = {_canonical_pair(feat_a, feat_b) for feat_a, feat_b in feature_pairs}
            shap_result["top_interactions"] = [
                (a, b, score)
                for a, b, score in shap_result["top_interactions"]
                if _canonical_pair(a, b) in requested
            ]

            results["shap"] = shap_result
        except ImportError:
            method_failures.append(
                (
                    "shap",
                    "shap library not installed. Install with: pip install ml4t-diagnostic[ml]",
                )
            )
        except Exception as e:
            method_failures.append(("shap", str(e)))

    # Check if at least one method succeeded
    if not results:
        error_msg = "All methods failed:\n" + "\n".join(
            f"  - {method}: {error}" for method, error in method_failures
        )
        raise ValueError(error_msg)

    # 2. Compute consensus ranking
    pair_keys = [_canonical_pair(feat_a, feat_b) for feat_a, feat_b in feature_pairs]

    # Ranked interaction lists for every method whose result format is recognized
    method_lists: dict[str, list[tuple[str, str, float]]] = {}
    for method_name, result in results.items():
        interactions = _method_interaction_list(result)
        if interactions is not None:
            method_lists[method_name] = interactions

    rankings: dict[str, NDArray[Any]] = {}
    method_scores: dict[str, dict[tuple[str, str], Any]] = {}
    for method_name, interactions in method_lists.items():
        pair_rank: dict[tuple[str, str], int] = {}
        pair_score: dict[tuple[str, str], Any] = {}
        for rank_idx, entry in enumerate(interactions):
            key = _canonical_pair(entry[0], entry[1])
            # If a method reports a pair twice, its best (first) position wins
            if key not in pair_rank:
                pair_rank[key] = rank_idx
                pair_score[key] = entry[2]
        # Pairs a method did not report rank just below everything it did report
        missing_rank = len(interactions)
        rankings[method_name] = np.array([pair_rank.get(key, missing_rank) for key in pair_keys])
        method_scores[method_name] = pair_score

    if rankings:
        avg_ranks = np.mean(list(rankings.values()), axis=0)
    else:
        avg_ranks = np.zeros(len(pair_keys))

    # Consensus entries carry each method's score; unavailable (None) scores are omitted
    consensus_ranking: list[tuple[str, str, float, dict[str, float]]] = []
    for (feat_a, feat_b), key, avg_rank in zip(feature_pairs, pair_keys, avg_ranks):
        scores_dict: dict[str, float] = {
            method_name: float(scores[key])
            for method_name, scores in method_scores.items()
            if scores.get(key) is not None
        }
        consensus_ranking.append((feat_a, feat_b, float(avg_rank), scores_dict))

    # Sort by average rank (stable: ties keep the requested pair order)
    consensus_ranking.sort(key=lambda item: item[2])

    # 3. Compute method agreement (Spearman correlation between rankings)
    method_agreement: dict[tuple[str, str], float] = {}
    method_names = list(rankings)
    for i, m1 in enumerate(method_names):
        for m2 in method_names[i + 1 :]:
            method_agreement[(m1, m2)] = _rank_agreement(rankings[m1], rankings[m2])

    # Undefined (NaN) correlations stay in the returned dict but are excluded
    # from summaries, where they would poison min/mean comparisons.
    finite_agreement = {
        pair: corr for pair, corr in method_agreement.items() if np.isfinite(corr)
    }

    # 4. Identify consensus top interactions (top 10 in all methods)
    top_n = 10
    top_sets = [
        {_canonical_pair(entry[0], entry[1]) for entry in interactions[:top_n]}
        for interactions in method_lists.values()
    ]
    consensus_top_pairs = set.intersection(*top_sets) if top_sets else set()

    # Order by consensus rank so the output does not depend on set iteration order
    consensus_position: dict[tuple[str, str], int] = {}
    for pos, (feat_a, feat_b, _, _) in enumerate(consensus_ranking):
        consensus_position.setdefault(_canonical_pair(feat_a, feat_b), pos)
    consensus_top_list = sorted(
        consensus_top_pairs,
        key=lambda key: (consensus_position.get(key, len(consensus_ranking)), key),
    )

    # 5. Generate warnings
    warning_messages: list[str] = []

    # Disagreement between Conditional IC and H-statistic top-5 pairs
    if "conditional_ic" in results and "h_statistic" in results:
        ic_top = {
            _canonical_pair(entry[0], entry[1])
            for entry in method_lists.get("conditional_ic", [])[:5]
        }
        h_top = {
            _canonical_pair(entry[0], entry[1])
            for entry in method_lists.get("h_statistic", [])[:5]
        }
        disagreement = ic_top - h_top
        if disagreement:
            pairs_str = ", ".join(f"({a}, {b})" for a, b in sorted(disagreement))
            warning_messages.append(
                f"Pairs {pairs_str} rank high in Conditional IC but not H-statistic - "
                "possible regime-specific interaction (time-varying)"
            )

    # Low agreement between methods
    if finite_agreement:
        min_agreement = min(finite_agreement.values())
        if min_agreement < 0.5:
            warning_messages.append(
                f"Low agreement between methods (min correlation: {min_agreement:.2f}) - "
                "results may be unreliable or methods capture different interaction types"
            )

    # Surface method failures as warnings too
    for method, error in method_failures:
        warning_messages.append(f"Method '{method}' failed: {error}")

    # 6. Generate interpretation
    top_pairs = [(a, b) for a, b, _, _ in consensus_ranking[:10]]
    interpretation = _generate_interaction_interpretation(
        top_pairs,
        finite_agreement,
        warning_messages,
        len(consensus_top_list),
    )

    return {
        "method_results": results,
        "consensus_ranking": consensus_ranking,
        "method_agreement": method_agreement,
        "top_interactions_consensus": consensus_top_list,
        "warnings": warning_messages,
        "interpretation": interpretation,
        "methods_run": list(results.keys()),
        "methods_failed": method_failures,
    }
|