ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,772 @@
+ """Feature interaction detection: H-statistic, SHAP interactions, and comprehensive analysis.
+
+ This module provides methods for detecting and analyzing feature interactions
+ including Friedman's H-statistic and SHAP interaction values.
+ """
+
+ import time
+ from typing import TYPE_CHECKING, Any, Union, cast
+
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+ from scipy.stats import spearmanr
+
+ from ml4t.diagnostic.evaluation.metrics.conditional_ic import compute_conditional_ic
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+
+ def compute_h_statistic(
+     model: Any,
+     X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
+     feature_pairs: list[tuple[int, int]] | list[tuple[str, str]] | None = None,
+     feature_names: list[str] | None = None,
+     n_samples: int = 100,
+     grid_resolution: int = 20,
+ ) -> dict[str, Any]:
+     """Compute Friedman's H-statistic for feature interaction strength.
+
+     The H-statistic (Friedman & Popescu 2008) measures how much of the variation
+     in predictions can be attributed to interactions between feature pairs, beyond
+     their individual main effects.
+
+     **Algorithm**:
+     1. For each feature pair (j, k):
+        - Compute the 2D partial dependence PD_{jk}(x_j, x_k)
+        - Compute the 1D partial dependences PD_j(x_j) and PD_k(x_k)
+        - Compute H^2 = sum[PD_{jk} - PD_j - PD_k]^2 / sum[PD_{jk}^2]
+          (all partial dependences centered)
+        - H ranges from 0 (no interaction) to 1 (pure interaction)
+
+     Parameters
+     ----------
+     model : Any
+         Trained model with a .predict() method
+     X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
+         Feature matrix (n_samples, n_features)
+     feature_pairs : list[tuple[int, int]] | list[tuple[str, str]] | None, default None
+         List of (i, j) pairs to test. If None, tests all pairs.
+     feature_names : list[str] | None, default None
+         Feature names. If None, uses column names or f0, f1, ...
+     n_samples : int, default 100
+         Number of samples to use for PD computation (subsampled if needed)
+     grid_resolution : int, default 20
+         Grid size for PD evaluation
+
+     Returns
+     -------
+     dict[str, Any]
+         Dictionary with:
+         - h_statistics: List of (feature_i, feature_j, H_value) sorted by H descending
+         - feature_names: List of feature names used
+         - n_features: Number of features
+         - n_pairs_tested: Number of pairs tested
+         - n_samples_used: Number of samples used
+         - grid_resolution: Grid size used
+         - computation_time: Time in seconds
+
+     References
+     ----------
+     - Friedman, J. H., & Popescu, B. E. (2008). Predictive learning via rule ensembles.
+       The Annals of Applied Statistics, 2(3), 916-954.
+
+     Examples
+     --------
+     >>> import lightgbm as lgb
+     >>> model = lgb.LGBMRegressor()
+     >>> model.fit(X_train, y_train)
+     >>> results = compute_h_statistic(model, X_test)
+     >>> for feat_i, feat_j, h_val in results["h_statistics"][:5]:
+     ...     print(f" {feat_i} x {feat_j}: H = {h_val:.4f}")
+     """
+     start_time = time.time()
+
+     # Convert input to numpy
+     if isinstance(X, pl.DataFrame):
+         if feature_names is None:
+             feature_names = X.columns
+         X_array = X.to_numpy()
+     elif isinstance(X, pd.DataFrame):
+         if feature_names is None:
+             feature_names = list(X.columns)
+         X_array = X.values
+     else:  # numpy array
+         X_array = X
+         if feature_names is None:
+             feature_names = [f"f{i}" for i in range(X_array.shape[1])]
+
+     n_total_samples, n_features = X_array.shape
+
+     # Subsample if needed
+     if n_total_samples > n_samples:
+         rng = np.random.RandomState(42)
+         indices = rng.choice(n_total_samples, size=n_samples, replace=False)
+         X_sample = X_array[indices]
+     else:
+         X_sample = X_array
+         n_samples = n_total_samples
+
+     # Generate feature pairs if not provided - always convert to int pairs
+     pairs_int: list[tuple[int, int]]
+     if feature_pairs is None:
+         # Test all pairs
+         pairs_int = [(i, j) for i in range(n_features) for j in range(i + 1, n_features)]
+     elif feature_names and len(feature_pairs) > 0 and isinstance(feature_pairs[0][0], str):
+         # Convert string pairs to indices
+         name_to_idx = {name: idx for idx, name in enumerate(feature_names)}
+         pairs_int = [(name_to_idx[str(i)], name_to_idx[str(j)]) for i, j in feature_pairs]
+     else:
+         # Already integer pairs
+         pairs_int = [(int(i), int(j)) for i, j in feature_pairs]
+
+     # Ensure feature_names is a list for indexing
+     feature_names_list: list[str] = list(feature_names) if feature_names is not None else []
+
+     h_results: list[tuple[str, str, float]] = []
+
+     for feat_i, feat_j in pairs_int:
+         # Create grids for features i and j
+         x_i_grid = np.linspace(
+             float(X_sample[:, feat_i].min()), float(X_sample[:, feat_i].max()), grid_resolution
+         )
+         x_j_grid = np.linspace(
+             float(X_sample[:, feat_j].min()), float(X_sample[:, feat_j].max()), grid_resolution
+         )
+
+         # Compute 2D partial dependence PD_{ij}
+         pd_2d = np.zeros((grid_resolution, grid_resolution))
+         for gi, x_i_val in enumerate(x_i_grid):
+             for gj, x_j_val in enumerate(x_j_grid):
+                 # Replace features i and j with grid values
+                 X_temp = X_sample.copy()
+                 X_temp[:, feat_i] = x_i_val
+                 X_temp[:, feat_j] = x_j_val
+                 # Average prediction over all samples
+                 pd_2d[gi, gj] = model.predict(X_temp).mean()
+
+         # Compute 1D partial dependences PD_i and PD_j
+         pd_i = np.zeros(grid_resolution)
+         for gi, x_i_val in enumerate(x_i_grid):
+             X_temp = X_sample.copy()
+             X_temp[:, feat_i] = x_i_val
+             pd_i[gi] = model.predict(X_temp).mean()
+
+         pd_j = np.zeros(grid_resolution)
+         for gj, x_j_val in enumerate(x_j_grid):
+             X_temp = X_sample.copy()
+             X_temp[:, feat_j] = x_j_val
+             pd_j[gj] = model.predict(X_temp).mean()
+
+         # Compute H-statistic
+         # H^2 = sum[PD_{ij} - PD_i - PD_j + PD_const]^2 / sum[PD_{ij}^2]
+
+         # For numerical stability, center everything
+         pd_const = pd_2d.mean()
+         pd_i_centered = pd_i - pd_const
+         pd_j_centered = pd_j - pd_const
+         pd_2d_centered = pd_2d - pd_const
+
+         # Interaction component: PD_{ij} - PD_i - PD_j
+         # Need to broadcast pd_i and pd_j to 2D
+         pd_i_broadcast = pd_i_centered[:, np.newaxis]  # Shape: (grid_resolution, 1)
+         pd_j_broadcast = pd_j_centered[np.newaxis, :]  # Shape: (1, grid_resolution)
+
+         interaction = pd_2d_centered - pd_i_broadcast - pd_j_broadcast
+
+         # H-statistic
+         numerator = np.sum(interaction**2)
+         denominator = np.sum(pd_2d_centered**2)
+
+         if denominator > 1e-10:  # Avoid division by zero
+             h_squared = numerator / denominator
+             h_stat = np.sqrt(max(0, h_squared))  # Ensure non-negative
+         else:
+             h_stat = 0.0
+
+         h_results.append((feature_names_list[feat_i], feature_names_list[feat_j], float(h_stat)))
+
+     # Sort by H-statistic descending
+     h_results.sort(key=lambda x: x[2], reverse=True)
+
+     computation_time = time.time() - start_time
+
+     return {
+         "h_statistics": h_results,
+         "feature_names": feature_names,
+         "n_features": n_features,
+         "n_pairs_tested": len(h_results),
+         "n_samples_used": n_samples,
+         "grid_resolution": grid_resolution,
+         "computation_time": computation_time,
+     }
+
+
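A quick way to sanity-check the H-statistic implementation above is to hand it toy models with known interaction structure. The sketch below is illustrative only (not part of the package); `_Additive` and `_Multiplicative` are hypothetical stand-ins for any object exposing a `.predict()` method. The additive model should score H near 0, the multiplicative one close to 1:

    import numpy as np
    from ml4t.diagnostic.evaluation.metrics.interactions import compute_h_statistic

    class _Additive:
        def predict(self, X):
            return X[:, 0] + X[:, 1]  # no interaction: pure main effects

    class _Multiplicative:
        def predict(self, X):
            return X[:, 0] * X[:, 1]  # pure interaction, no main effects

    rng = np.random.default_rng(0)
    X = rng.normal(size=(300, 2))

    h_add = compute_h_statistic(_Additive(), X, n_samples=100, grid_resolution=10)
    h_mul = compute_h_statistic(_Multiplicative(), X, n_samples=100, grid_resolution=10)

    print(h_add["h_statistics"])  # [('f0', 'f1', ~0.0)]
    print(h_mul["h_statistics"])  # [('f0', 'f1', ~1.0)]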
+ def compute_shap_interactions(
+     model: Any,
+     X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
+     feature_names: list[str] | None = None,
+     _check_additivity: bool = False,
+     max_samples: int | None = None,
+     top_k: int | None = None,
+ ) -> dict[str, Any]:
+     """Compute SHAP interaction values for feature pairs.
+
+     SHAP interaction values decompose the SHAP value of each feature into:
+     - Main effect (the feature's individual contribution)
+     - Interaction effects (how the feature's impact changes with other features)
+
+     Parameters
+     ----------
+     model : Any
+         Trained tree-based model
+     X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
+         Feature matrix (n_samples, n_features)
+     feature_names : list[str] | None, default None
+         Feature names. If None, uses column names or f0, f1, ...
+     _check_additivity : bool, default False
+         Internal parameter (not used for interaction values)
+     max_samples : int | None, default None
+         Maximum samples to use (subsample if larger)
+     top_k : int | None, default None
+         Return only top K interactions by absolute magnitude
+
+     Returns
+     -------
+     dict[str, Any]
+         Dictionary with:
+         - interaction_matrix: (n_features, n_features) mean absolute interactions
+         - feature_names: List of feature names
+         - top_interactions: List of (feature_i, feature_j, mean_interaction) sorted by magnitude
+         - n_features: Number of features
+         - n_samples_used: Number of samples used
+         - computation_time: Time in seconds
+
+     Notes
+     -----
+     - Requires shap package (install with: pip install ml4t-diagnostic[ml])
+     - Only works with tree-based models (uses TreeExplainer)
+     - Interaction matrix is symmetric: interaction(i,j) = interaction(j,i)
+     """
+     start_time = time.time()
+
+     # Check shap availability
+     try:
+         import shap
+     except ImportError as e:
+         raise ImportError(
+             "SHAP is required for interaction values. "
+             "Install with: pip install ml4t-diagnostic[ml] "
+             "or: pip install shap>=0.43.0"
+         ) from e
+
+     # Convert input to numpy and extract feature names
+     if isinstance(X, pl.DataFrame):
+         if feature_names is None:
+             feature_names = X.columns
+         X_array = X.to_numpy()
+     elif isinstance(X, pd.DataFrame):
+         if feature_names is None:
+             feature_names = list(X.columns)
+         X_array = X.values
+     else:  # numpy array
+         X_array = X
+         if feature_names is None:
+             feature_names = [f"f{i}" for i in range(X_array.shape[1])]
+
+     # Type assertion: feature_names is guaranteed to be set at this point
+     assert feature_names is not None, "feature_names should be set by this point"
+
+     n_total_samples, n_features = X_array.shape
+
+     # Subsample if needed
+     if max_samples is not None and n_total_samples > max_samples:
+         rng = np.random.RandomState(42)
+         indices = rng.choice(n_total_samples, size=max_samples, replace=False)
+         X_sample = X_array[indices]
+         n_samples_used = max_samples
+     else:
+         X_sample = X_array
+         n_samples_used = n_total_samples
+
+     # Compute SHAP interaction values using TreeExplainer
+     explainer = shap.TreeExplainer(model)
+     shap_interaction_values = explainer.shap_interaction_values(X_sample)
+
+     # Handle multi-output models (classification)
+     if isinstance(shap_interaction_values, list):
+         # List format: use positive class for binary, average for multiclass
+         if len(shap_interaction_values) == 2:
+             shap_interaction_values = shap_interaction_values[1]
+         else:
+             shap_interaction_values = np.mean(shap_interaction_values, axis=0)
+
+     # Check if we have a 4D array (n_samples, n_features, n_features, n_classes)
+     if shap_interaction_values.ndim == 4:
+         if shap_interaction_values.shape[-1] == 2:
+             # Binary classification: use positive class (index 1)
+             shap_interaction_values = shap_interaction_values[:, :, :, 1]
+         else:
+             # Multiclass: average absolute values across classes
+             shap_interaction_values = np.mean(np.abs(shap_interaction_values), axis=-1)
+
+     # Shape should now be: (n_samples, n_features, n_features)
+
+     # Compute mean absolute interaction matrix
+     interaction_matrix = np.mean(np.abs(shap_interaction_values), axis=0)
+
+     # Ensure 2D matrix (n_features, n_features)
+     if interaction_matrix.ndim != 2:
+         raise ValueError(
+             f"Interaction matrix should be 2D but got shape {interaction_matrix.shape}. "
+             f"Raw SHAP values shape: {shap_interaction_values.shape}"
+         )
+
+     # Extract top interactions (off-diagonal, upper triangle to avoid duplicates)
+     interactions_list = []
+     for i in range(n_features):
+         for j in range(i + 1, n_features):  # Upper triangle only
+             mean_interaction = float(interaction_matrix[i, j])
+             interactions_list.append((feature_names[i], feature_names[j], mean_interaction))
+
+     # Sort by absolute interaction strength descending
+     interactions_list.sort(key=lambda x: abs(x[2]), reverse=True)
+
+     # Limit to top K if requested
+     if top_k is not None:
+         interactions_list = interactions_list[:top_k]
+
+     computation_time = time.time() - start_time
+
+     return {
+         "interaction_matrix": interaction_matrix,
+         "feature_names": feature_names,
+         "top_interactions": interactions_list,
+         "n_features": n_features,
+         "n_samples_used": n_samples_used,
+         "computation_time": computation_time,
+     }
+
+
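For tree models, the function above delegates the heavy lifting to shap.TreeExplainer and its shap_interaction_values() method. A minimal usage sketch, assuming lightgbm and shap are installed (the synthetic target plants an interaction between f0 and f1, which should surface as the top pair):

    import numpy as np
    import lightgbm as lgb
    from ml4t.diagnostic.evaluation.metrics.interactions import compute_shap_interactions

    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 3))
    y = X[:, 0] * X[:, 1] + X[:, 2]  # f0 x f1 interaction plus a main effect

    model = lgb.LGBMRegressor(n_estimators=50).fit(X, y)
    res = compute_shap_interactions(model, X, max_samples=200, top_k=3)

    print(res["interaction_matrix"].shape)  # (3, 3), symmetric
    print(res["top_interactions"][0])       # expected: ('f0', 'f1', <largest magnitude>)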
+ def _generate_interaction_interpretation(
+     top_interactions: list[tuple[str, str]],
+     method_agreement: dict[tuple[str, str], float],
+     warnings: list[str],
+     n_consensus: int,
+ ) -> str:
+     """Generate human-readable interpretation of interaction analysis.
+
+     Parameters
+     ----------
+     top_interactions : list[tuple[str, str]]
+         Top feature pairs from consensus ranking
+     method_agreement : dict[tuple[str, str], float]
+         Pairwise correlations between method rankings
+     warnings : list[str]
+         List of potential issues detected
+     n_consensus : int
+         Number of interactions in top 10 across all methods
+
+     Returns
+     -------
+     str
+         Human-readable interpretation summary
+     """
+     lines = []
+
+     # Consensus interactions
+     if n_consensus > 0:
+         lines.append(
+             f"Strong consensus: {n_consensus} interactions rank in top 10 across all methods"
+         )
+         pairs_str = ", ".join([f"({a}, {b})" for a, b in top_interactions[:3]])
+         lines.append(f" Top consensus interactions: {pairs_str}")
+     else:
+         lines.append("Weak consensus: Different methods identify different important interactions")
+
+     # Method agreement
+     if method_agreement:
+         avg_agreement = float(np.mean(list(method_agreement.values())))
+         if avg_agreement > 0.7:
+             lines.append(f"High agreement between methods (avg correlation: {avg_agreement:.2f})")
+         elif avg_agreement > 0.5:
+             lines.append(
+                 f"Moderate agreement between methods (avg correlation: {avg_agreement:.2f})"
+             )
+         else:
+             lines.append(
+                 f"Low agreement between methods (avg correlation: {avg_agreement:.2f}) - investigate further"
+             )
+
+     # Warnings
+     if warnings:
+         lines.append("\nPotential Issues:")
+         for warning in warnings:
+             lines.append(f" - {warning}")
+
+     return "\n".join(lines)
+
+
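The helper is private, but its output format can be pinned down by tracing the code above. With hypothetical inputs (the feature names here are invented for illustration), it would produce:

    from ml4t.diagnostic.evaluation.metrics.interactions import _generate_interaction_interpretation

    summary = _generate_interaction_interpretation(
        top_interactions=[("momentum", "volatility"), ("size", "value")],
        method_agreement={("conditional_ic", "h_statistic"): 0.82},
        warnings=[],
        n_consensus=2,
    )
    print(summary)
    # Strong consensus: 2 interactions rank in top 10 across all methods
    #  Top consensus interactions: (momentum, volatility), (size, value)
    # High agreement between methods (avg correlation: 0.82)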
+ def analyze_interactions(
+     model: Any,
+     X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
+     y: Union[pl.Series, pd.Series, "NDArray[Any]"],
+     feature_pairs: list[tuple[str, str]] | None = None,
+     methods: list[str] | None = None,
+     n_quantiles: int = 5,
+     grid_resolution: int = 20,
+     max_samples: int = 200,
+ ) -> dict[str, Any]:
+     """Comprehensive feature interaction analysis comparing multiple methods.
+
+     **This is a TEAR SHEET function** - it runs multiple interaction detection methods
+     and generates a comparison report with consensus ranking and interpretation.
+
+     **Use Case**: "Which feature pairs interact in my model? Do different methods agree?"
+
+     This function replaces 100+ lines of manual comparison code by providing
+     integrated analysis showing:
+     - Individual method results (Conditional IC, H-statistic, SHAP interactions)
+     - Consensus ranking (interactions important across methods)
+     - Method agreement/disagreement analysis
+     - Auto-generated insights and warnings
+
+     Parameters
+     ----------
+     model : Any
+         Fitted model. Requirements vary by method:
+         - Conditional IC: Not used (analyzes feature correlations)
+         - H-statistic: Must have a `predict()` method
+         - SHAP: Must be compatible with TreeExplainer
+     X : Union[pl.DataFrame, pd.DataFrame, np.ndarray]
+         Feature matrix (n_samples, n_features)
+     y : Union[pl.Series, pd.Series, np.ndarray]
+         Target values (n_samples,)
+     feature_pairs : list[tuple[str, str]] | None, default None
+         Specific feature pairs to analyze. If None, tests all pairs.
+     methods : list[str] | None, default ["conditional_ic", "h_statistic", "shap"]
+         Which methods to run.
+     n_quantiles : int, default 5
+         Number of quantile bins for Conditional IC
+     grid_resolution : int, default 20
+         Grid size for partial dependence in H-statistic
+     max_samples : int, default 200
+         Maximum samples for SHAP and H-statistic
+
+     Returns
+     -------
+     dict[str, Any]
+         Comprehensive analysis results:
+         - method_results: Dict of individual method outputs
+         - consensus_ranking: Feature pairs ranked by average rank across methods
+         - method_agreement: Spearman correlations between method rankings
+         - top_interactions_consensus: Pairs in top 10 for ALL methods
+         - warnings: Detected issues
+         - interpretation: Auto-generated summary
+         - methods_run: Methods successfully executed
+         - methods_failed: Failed methods with error messages
+
+     Raises
+     ------
+     ValueError
+         If all methods fail or no methods are specified
+     """
+     if methods is None:
+         methods = ["conditional_ic", "h_statistic", "shap"]
+
+     if not methods:
+         raise ValueError("At least one method must be specified")
+
+     # Extract feature names if not provided
+     if isinstance(X, pl.DataFrame | pd.DataFrame):
+         feature_names = list(X.columns)
+     else:
+         # Generate numeric feature names
+         n_features = X.shape[1] if hasattr(X, "shape") else len(X[0])
+         feature_names = [f"f{i}" for i in range(n_features)]
+
+     # Determine feature pairs to analyze
+     if feature_pairs is None:
+         # Test all pairs
+         n_features = len(feature_names)
+         all_pairs = []
+         for i in range(n_features):
+             for j in range(i + 1, n_features):
+                 all_pairs.append((feature_names[i], feature_names[j]))
+         feature_pairs = all_pairs
+     else:
+         # Validate provided pairs
+         feature_set = set(feature_names)
+         for pair in feature_pairs:
+             if len(pair) != 2:
+                 raise ValueError(f"Feature pair must have exactly 2 elements: {pair}")
+             if pair[0] not in feature_set or pair[1] not in feature_set:
+                 raise ValueError(
+                     f"Feature pair contains unknown features: {pair}. Available features: {feature_names}"
+                 )
+
+     # 1. Run each method with try/except for optional dependencies and errors
+     results = {}
+     method_failures = []
+
+     if "conditional_ic" in methods:
+         try:
+             # For Conditional IC, we need to run it for each pair
+             ic_results: list[tuple[str, str, float | None]] = []
+             for feat_a, feat_b in feature_pairs:
+                 # Extract columns
+                 x_a: pl.Series | pd.Series | NDArray[Any]
+                 x_b: pl.Series | pd.Series | NDArray[Any]
+                 if isinstance(X, pl.DataFrame):
+                     x_a = X[feat_a]
+                     x_b = X[feat_b]
+                 elif isinstance(X, pd.DataFrame):
+                     x_a = X[feat_a]
+                     x_b = X[feat_b]
+                 else:
+                     # numpy array - need to find indices
+                     idx_a = feature_names.index(feat_a)
+                     idx_b = feature_names.index(feat_b)
+                     X_arr = cast("NDArray[Any]", X)
+                     x_a = X_arr[:, idx_a]
+                     x_b = X_arr[:, idx_b]
+
+                 result = compute_conditional_ic(
+                     feature_a=x_a,
+                     feature_b=x_b,
+                     forward_returns=y,
+                     n_quantiles=n_quantiles,
+                 )
+
+                 # Extract interaction strength metric
+                 ic_range = result.get("ic_range", 0.0)
+                 ic_results.append((feat_a, feat_b, ic_range))
+
+             # Sort by IC range descending
+             ic_results.sort(key=lambda x: abs(x[2]) if x[2] is not None else 0.0, reverse=True)
+
+             results["conditional_ic"] = {
+                 "top_interactions": ic_results,
+                 "n_pairs_tested": len(ic_results),
+             }
+         except Exception as e:
+             method_failures.append(("conditional_ic", str(e)))
+
+     if "h_statistic" in methods:
+         try:
+             # Convert feature pairs to indices for h_statistic
+             pair_indices = []
+             for feat_a, feat_b in feature_pairs:
+                 idx_a = feature_names.index(feat_a)
+                 idx_b = feature_names.index(feat_b)
+                 pair_indices.append((idx_a, idx_b))
+
+             results["h_statistic"] = compute_h_statistic(
+                 model,
+                 X,
+                 feature_pairs=pair_indices,
+                 feature_names=feature_names,
+                 n_samples=max_samples,
+                 grid_resolution=grid_resolution,
+             )
+         except Exception as e:
+             method_failures.append(("h_statistic", str(e)))
+
+     if "shap" in methods:
+         try:
+             shap_result = compute_shap_interactions(
+                 model,
+                 X,
+                 feature_names=feature_names,
+                 max_samples=max_samples,
+             )
+
+             # Filter to requested pairs if feature_pairs was specified
+             if feature_pairs is not None:
+                 pair_set = set(feature_pairs) | {(b, a) for a, b in feature_pairs}
+                 filtered_interactions = [
+                     (a, b, score)
+                     for a, b, score in shap_result["top_interactions"]
+                     if (a, b) in pair_set or (b, a) in pair_set
+                 ]
+                 shap_result["top_interactions"] = filtered_interactions
+
+             results["shap"] = shap_result
+         except ImportError:
+             method_failures.append(
+                 (
+                     "shap",
+                     "shap library not installed. Install with: pip install ml4t-diagnostic[ml]",
+                 )
+             )
+         except Exception as e:
+             method_failures.append(("shap", str(e)))
+
+     # Check if at least one method succeeded
+     if not results:
+         error_msg = "All methods failed:\n" + "\n".join(
+             f" - {method}: {error}" for method, error in method_failures
+         )
+         raise ValueError(error_msg)
+
+     # 2. Compute consensus ranking
+     rankings: dict[str, NDArray[Any]] = {}
+     for method_name, result in results.items():
+         # Get interaction scores for this method
+         method_interactions: list[tuple[str, str, float]]
+         if "top_interactions" in result:
+             method_interactions = cast(list[tuple[str, str, float]], result["top_interactions"])
+         elif "h_statistics" in result:
+             method_interactions = cast(list[tuple[str, str, float]], result["h_statistics"])
+         else:
+             continue
+
+         # Create a mapping from pair to rank
+         pair_to_rank: dict[tuple[str, str], int] = {}
+         for rank_idx, interaction_tuple in enumerate(method_interactions):
+             feat_a_int, feat_b_int = str(interaction_tuple[0]), str(interaction_tuple[1])
+             pair_key = (min(feat_a_int, feat_b_int), max(feat_a_int, feat_b_int))
+             pair_to_rank[pair_key] = rank_idx
+
+         # Map all requested pairs to ranks (handle missing pairs)
+         ranks_array: list[int] = []
+         for feat_a, feat_b in feature_pairs:
+             pair_key = (min(feat_a, feat_b), max(feat_a, feat_b))
+             rank_val = pair_to_rank.get(pair_key, len(method_interactions))
+             ranks_array.append(rank_val)
+
+         rankings[method_name] = np.array(ranks_array)
+
+     # Average ranks across methods
+     avg_ranks = np.mean(list(rankings.values()), axis=0)
+
+     # Create consensus ranking with scores from each method
+     consensus_ranking: list[tuple[str, str, float, dict[str, float]]] = []
+     for idx, avg_rank in enumerate(avg_ranks):
+         feat_a, feat_b = feature_pairs[idx]
+         pair_tuple: tuple[str, str] = (min(feat_a, feat_b), max(feat_a, feat_b))
+
+         # Collect scores from each method
+         scores_dict: dict[str, float] = {}
+         for method_name, result in results.items():
+             method_ints: list[tuple[str, str, float]]
+             if "top_interactions" in result:
+                 method_ints = cast(list[tuple[str, str, float]], result["top_interactions"])
+             elif "h_statistics" in result:
+                 method_ints = cast(list[tuple[str, str, float]], result["h_statistics"])
+             else:
+                 continue
+
+             for int_tuple in method_ints:
+                 check_pair = (
+                     min(str(int_tuple[0]), str(int_tuple[1])),
+                     max(str(int_tuple[0]), str(int_tuple[1])),
+                 )
+                 if check_pair == pair_tuple:
+                     scores_dict[method_name] = float(int_tuple[2])
+                     break
+
+         consensus_ranking.append((feat_a, feat_b, float(avg_rank), scores_dict))
+
+     # Sort by average rank
+     consensus_ranking.sort(key=lambda x: x[2])
+
+     # 3. Compute method agreement (Spearman correlation between rankings)
+     method_agreement = {}
+     method_names = list(rankings.keys())
+     for i, m1 in enumerate(method_names):
+         for m2 in method_names[i + 1 :]:
+             corr, _ = spearmanr(rankings[m1], rankings[m2])
+             method_agreement[(m1, m2)] = float(corr)
+
+     # 4. Identify consensus top interactions (top 10 in all methods)
+     top_n = 10
+     top_interactions_by_method: dict[str, set[tuple[str, str]]] = {}
+     for method_name, result in results.items():
+         method_ints_list: list[tuple[str, str, float]]
+         if "top_interactions" in result:
+             method_ints_list = cast(list[tuple[str, str, float]], result["top_interactions"])
+         elif "h_statistics" in result:
+             method_ints_list = cast(list[tuple[str, str, float]], result["h_statistics"])
+         else:
+             continue
+
+         method_top_pairs: list[tuple[str, str]] = []
+         for int_entry in method_ints_list[:top_n]:
+             pair_sorted: tuple[str, str] = (
+                 min(str(int_entry[0]), str(int_entry[1])),
+                 max(str(int_entry[0]), str(int_entry[1])),
+             )
+             method_top_pairs.append(pair_sorted)
+         top_interactions_by_method[method_name] = set(method_top_pairs)
+
+     if top_interactions_by_method:
+         consensus_top_pairs = set.intersection(*top_interactions_by_method.values())
+     else:
+         consensus_top_pairs = set()
+
+     consensus_top_list = list(consensus_top_pairs)
+
+     # 5. Generate warnings
+     warnings = []
+
+     # Warning: Disagreement between specific methods
+     if "conditional_ic" in results and "h_statistic" in results:
+         ic_interactions: list[tuple[str, str, float]]
+         if "top_interactions" in results["conditional_ic"]:
+             ic_interactions = cast(
+                 list[tuple[str, str, float]], results["conditional_ic"]["top_interactions"]
+             )
+         else:
+             ic_interactions = []
+
+         h_interactions: list[tuple[str, str, float]] = cast(
+             list[tuple[str, str, float]], results["h_statistic"].get("h_statistics", [])
+         )
+
+         ic_top: set[tuple[str, str]] = {
+             (min(str(x[0]), str(x[1])), max(str(x[0]), str(x[1]))) for x in ic_interactions[:5]
+         }
+         h_top: set[tuple[str, str]] = {
+             (min(str(x[0]), str(x[1])), max(str(x[0]), str(x[1]))) for x in h_interactions[:5]
+         }
+
+         disagreement = ic_top - h_top
+         if disagreement:
+             pairs_str = ", ".join([f"({a}, {b})" for a, b in disagreement])
+             warnings.append(
+                 f"Pairs {pairs_str} rank high in Conditional IC but not H-statistic - "
+                 "possible regime-specific interaction (time-varying)"
+             )
+
+     # Warning: Low agreement between methods
+     if method_agreement:
+         min_agreement = min(method_agreement.values())
+         if min_agreement < 0.5:
+             warnings.append(
+                 f"Low agreement between methods (min correlation: {min_agreement:.2f}) - "
+                 "results may be unreliable or methods capture different interaction types"
+             )
+
+     # Add method failures to warnings
+     if method_failures:
+         for method, error in method_failures:
+             warnings.append(f"Method '{method}' failed: {error}")
+
+     # 6. Generate interpretation
+     top_pairs = [(a, b) for a, b, _, _ in consensus_ranking[:10]]
+     interpretation = _generate_interaction_interpretation(
+         top_pairs,
+         method_agreement,
+         warnings,
+         len(consensus_top_list),
+     )
+
+     return {
+         "method_results": results,
+         "consensus_ranking": consensus_ranking,
+         "method_agreement": method_agreement,
+         "top_interactions_consensus": consensus_top_list,
+         "warnings": warnings,
+         "interpretation": interpretation,
+         "methods_run": list(results.keys()),
+         "methods_failed": method_failures,
+     }
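Putting it together, an end-to-end sketch of the tear-sheet entry point. This again assumes lightgbm and shap are installed; the column names and the planted mom x vol interaction are illustrative, not from the package:

    import numpy as np
    import pandas as pd
    import lightgbm as lgb
    from ml4t.diagnostic.evaluation.metrics.interactions import analyze_interactions

    rng = np.random.default_rng(7)
    X = pd.DataFrame(rng.normal(size=(1000, 4)), columns=["mom", "vol", "size", "value"])
    y = X["mom"] * X["vol"] + 0.1 * rng.normal(size=1000)  # planted mom x vol interaction

    model = lgb.LGBMRegressor(n_estimators=100).fit(X, y)

    report = analyze_interactions(model, X, y, max_samples=200)
    print(report["interpretation"])
    for feat_a, feat_b, avg_rank, scores in report["consensus_ranking"][:3]:
        print(feat_a, feat_b, round(avg_rank, 1), scores)  # per-method scores for each pair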