ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
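The distribution installs a single ml4t.diagnostic namespace package (note the py.typed marker, so type hints are exported). A hypothetical install-and-import smoke test, assuming the alpha wheel is reachable by pip (from an index or a local file):

# Hypothetical smoke test; adjust the install source to wherever the wheel lives.
#   pip install ml4t_diagnostic-0.1.0a1-py3-none-any.whl
#   (or, if published: pip install ml4t-diagnostic==0.1.0a1)
import ml4t.diagnostic

print(ml4t.diagnostic.__name__)  # ml4t.diagnostic

The hunks below correspond to three of the files listed above, identified by their line counts in the RECORD.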
ml4t/diagnostic/visualization/data_extraction/interaction.py
@@ -0,0 +1,504 @@
+ """Interaction data extraction for visualization layer.
+
+ Extracts comprehensive visualization data from feature interaction analysis results.
+ """
+
+ from __future__ import annotations
+
+ from datetime import datetime
+ from typing import Any, cast
+
+ import numpy as np
+
+ from .types import (
+     FeatureInteractionData,
+     InteractionMatrixData,
+     InteractionVizData,
+     LLMContextData,
+     NetworkGraphData,
+ )
+ from .validation import _validate_matrix_feature_alignment
+
+
+ def extract_interaction_viz_data(
+     interaction_results: dict[str, Any],
+     importance_results: dict[str, Any] | None = None,
+     n_top_partners: int = 5,
+     cluster_threshold: float = 0.3,
+     include_llm_context: bool = True,
+ ) -> InteractionVizData:
+     """Extract comprehensive visualization data from interaction analysis results.
+
+     This function transforms raw SHAP interaction results into structured data
+     optimized for rich interactive visualization, including per-feature summaries,
+     network graph data, interaction matrices, and auto-generated insights.
+
+     Parameters
+     ----------
+     interaction_results : dict
+         Results from compute_shap_interactions() containing:
+         - 'interaction_matrix': DataFrame with pairwise interactions
+         - 'feature_names': list of feature names
+         - 'shap_values': raw SHAP values (optional)
+         - 'shap_interaction_values': raw interaction values (optional)
+     importance_results : dict, optional
+         Optional importance results to cross-reference for node sizing.
+         If provided, the consensus ranking is used to size network nodes.
+     n_top_partners : int, default=5
+         Number of top interaction partners to include per feature.
+     cluster_threshold : float, default=0.3
+         Minimum interaction strength to consider for clustering.
+         Features with interactions above this threshold are clustered.
+     include_llm_context : bool, default=True
+         Whether to generate auto-narratives for LLM consumption.
+
+     Returns
+     -------
+     InteractionVizData
+         Complete structured data package with:
+         - Per-feature interaction summaries
+         - Network graph data (nodes, edges, clusters)
+         - Interaction matrix data
+         - Strength distribution statistics
+         - Auto-generated LLM narratives
+
+     Examples
+     --------
+     >>> from ml4t.diagnostic.evaluation import compute_shap_interactions
+     >>> from ml4t.diagnostic.visualization.data_extraction import extract_interaction_viz_data
+     >>>
+     >>> # Compute interactions
+     >>> interaction_results = compute_shap_interactions(model, X, y)
+     >>>
+     >>> # Extract visualization data
+     >>> viz_data = extract_interaction_viz_data(interaction_results)
+     >>>
+     >>> # Access different views
+     >>> print(viz_data['summary']['strongest_interaction'])
+     >>> print(viz_data['per_feature']['momentum']['top_partners'][:3])
+     >>> print(viz_data['network_graph']['nodes'])
+     >>> print(viz_data['llm_context']['key_insights'])
+
+     Notes
+     -----
+     - Network graph data is pre-computed for custom rendering
+     - Clustering identifies groups of strongly interacting features
+     - Per-feature summaries enable drill-down dashboards
+     - Cross-referencing with importance results enables better node sizing
+     """
+     # Extract basic info
+     interaction_matrix_df = interaction_results.get("interaction_matrix")
+     feature_names = interaction_results.get("feature_names", [])
+
+     if interaction_matrix_df is None:
+         raise ValueError("interaction_results must contain 'interaction_matrix'")
+
+     # Convert to numpy for easier manipulation
+     if hasattr(interaction_matrix_df, "to_numpy"):
+         interaction_matrix = interaction_matrix_df.to_numpy()
+     else:
+         interaction_matrix = np.array(interaction_matrix_df)
+
+     # Validate matrix dimensions match feature names
+     _validate_matrix_feature_alignment(interaction_matrix, feature_names)
+
+     n_features = len(feature_names)
+
+     # Build summary statistics
+     summary = _build_interaction_summary(interaction_matrix, feature_names)
+
+     # Build per-feature interaction data
+     per_feature = _build_per_feature_interactions(interaction_matrix, feature_names, n_top_partners)
+
+     # Build network graph data
+     network_graph = _build_network_graph(
+         interaction_matrix, feature_names, importance_results, cluster_threshold
+     )
+
+     # Build matrix data
+     matrix_data = _build_interaction_matrix_data(interaction_matrix, feature_names)
+
+     # Build strength distribution
+     strength_distribution = _build_strength_distribution(interaction_matrix)
+
+     # Build metadata
+     metadata = {
+         "n_features": n_features,
+         "n_interactions": int(n_features * (n_features - 1) / 2),
+         "analysis_timestamp": datetime.now().isoformat(),
+         "cluster_threshold": cluster_threshold,
+         "n_top_partners": n_top_partners,
+     }
+
+     # Generate LLM context (empty placeholder when narratives are disabled)
+     llm_context: LLMContextData = {
+         "summary_narrative": "",
+         "key_insights": [],
+         "recommendations": [],
+         "caveats": [],
+         "analysis_quality": "medium",
+     }
+     if include_llm_context:
+         llm_context = _generate_interaction_llm_context(
+             summary, per_feature, network_graph, strength_distribution
+         )
+
+     return InteractionVizData(
+         summary=summary,
+         per_feature=per_feature,
+         network_graph=network_graph,
+         interaction_matrix=matrix_data,
+         strength_distribution=strength_distribution,
+         metadata=metadata,
+         llm_context=llm_context,
+     )
+
+
+ # =============================================================================
+ # Interaction Analysis Helpers
+ # =============================================================================
+
+
+ def _build_interaction_summary(
+     interaction_matrix: np.ndarray, feature_names: list[str]
+ ) -> dict[str, Any]:
+     """Build high-level summary statistics for interactions."""
+     n_features = len(feature_names)
+
+     # Get upper triangle (exclude diagonal)
+     triu_indices = np.triu_indices(n_features, k=1)
+     interaction_values = interaction_matrix[triu_indices]
+
+     # Find strongest interaction
+     abs_values = np.abs(interaction_values)
+     max_idx = np.argmax(abs_values)
+     max_interaction = float(interaction_values[max_idx])
+
+     # Get feature pair for strongest interaction
+     i, j = triu_indices[0][max_idx], triu_indices[1][max_idx]
+     strongest_pair = (feature_names[i], feature_names[j])
+
+     # Compute distribution statistics
+     mean_interaction = float(np.mean(abs_values))
+     median_interaction = float(np.median(abs_values))
+     std_interaction = float(np.std(abs_values))
+
+     # Identify features with strongest overall interactions
+     total_interactions = np.sum(np.abs(interaction_matrix), axis=1)
+     top_idx = np.argmax(total_interactions)
+     most_interactive_feature = feature_names[top_idx]
+
+     return {
+         "n_features": n_features,
+         "n_interactions": len(interaction_values),
+         "strongest_interaction": max_interaction,
+         "strongest_pair": strongest_pair,
+         "mean_interaction": mean_interaction,
+         "median_interaction": median_interaction,
+         "std_interaction": std_interaction,
+         "most_interactive_feature": most_interactive_feature,
+         "max_total_interaction": float(total_interactions[top_idx]),
+     }
+
+
+ def _build_per_feature_interactions(
+     interaction_matrix: np.ndarray, feature_names: list[str], n_top_partners: int = 5
+ ) -> dict[str, FeatureInteractionData]:
+     """Build per-feature interaction summaries."""
+     per_feature: dict[str, FeatureInteractionData] = {}
+     n_features = len(feature_names)
+
+     for i, feature_name in enumerate(feature_names):
+         # Get all interactions for this feature
+         interactions = interaction_matrix[i, :]
+
+         # Exclude self-interaction
+         partner_indices = [j for j in range(n_features) if j != i]
+         partner_interactions = [(feature_names[j], float(interactions[j])) for j in partner_indices]
+
+         # Sort by absolute interaction strength
+         partner_interactions.sort(key=lambda x: abs(x[1]), reverse=True)
+
+         # Get top N partners
+         top_partners = partner_interactions[:n_top_partners]
+
+         # Total interaction strength (sums the full row, so the diagonal term is included)
+         total_strength = float(np.sum(np.abs(interactions)))
+
+         # Generate interpretation
+         interpretation = _generate_interaction_interpretation(feature_name, top_partners)
+
+         per_feature[feature_name] = FeatureInteractionData(
+             feature_name=feature_name,
+             top_partners=top_partners,
+             total_interaction_strength=total_strength,
+             cluster_id=None,  # Cluster membership is reported via network_graph["clusters"]
+             interpretation=interpretation,
+         )
+
+     return per_feature
+
+
+ def _build_network_graph(
+     interaction_matrix: np.ndarray,
+     feature_names: list[str],
+     importance_results: dict[str, Any] | None,
+     cluster_threshold: float,
+ ) -> NetworkGraphData:
+     """Build network graph data (nodes, edges, clusters)."""
+     n_features = len(feature_names)
+
+     # Build nodes
+     nodes = []
+     for i, feature_name in enumerate(feature_names):
+         # Node importance (for sizing) - use importance if available
+         if importance_results and "consensus_ranking" in importance_results:
+             consensus_ranking = importance_results["consensus_ranking"]
+             if feature_name in consensus_ranking:
+                 rank = consensus_ranking.index(feature_name) + 1
+                 # Rank 1 is most important, so its inverse yields the largest node
+                 node_importance = 1.0 / rank
+             else:
+                 node_importance = 0.1
+         else:
+             # Use total interaction strength as proxy
+             node_importance = float(np.sum(np.abs(interaction_matrix[i, :])))
+
+         nodes.append(
+             {
+                 "id": feature_name,
+                 "label": feature_name,
+                 "importance": node_importance,
+                 "total_interaction": float(np.sum(np.abs(interaction_matrix[i, :]))),
+             }
+         )
+
+     # Build edges (only upper triangle to avoid duplicates)
+     edges = []
+     for i in range(n_features):
+         for j in range(i + 1, n_features):
+             interaction_value = float(interaction_matrix[i, j])
+             if abs(interaction_value) > 0:  # Include all non-zero interactions
+                 edges.append(
+                     {
+                         "source": feature_names[i],
+                         "target": feature_names[j],
+                         "weight": interaction_value,
+                         "abs_weight": abs(interaction_value),
+                     }
+                 )
+
+     # Sort edges by absolute weight
+     edges.sort(key=lambda e: cast(float, e["abs_weight"]), reverse=True)
+
+     # Perform simple clustering based on strong interactions
+     clusters = _detect_interaction_clusters(interaction_matrix, feature_names, cluster_threshold)
+
+     return NetworkGraphData(nodes=nodes, edges=edges, clusters=clusters)
+
+
+ def _build_interaction_matrix_data(
+     interaction_matrix: np.ndarray, feature_names: list[str]
+ ) -> InteractionMatrixData:
+     """Build matrix data for heatmap visualization."""
+     # Convert to list of lists for JSON serialization
+     matrix_list = interaction_matrix.tolist()
+
+     # Compute statistics
+     triu_indices = np.triu_indices(len(feature_names), k=1)
+     interaction_values = interaction_matrix[triu_indices]
+
+     max_interaction = float(np.max(np.abs(interaction_values)))
+     mean_interaction = float(np.mean(np.abs(interaction_values)))
+
+     return InteractionMatrixData(
+         features=feature_names,
+         matrix=matrix_list,
+         max_interaction=max_interaction,
+         mean_interaction=mean_interaction,
+     )
+
+
+ def _build_strength_distribution(interaction_matrix: np.ndarray) -> dict[str, Any]:
+     """Build distribution statistics for interaction strengths."""
+     n_features = interaction_matrix.shape[0]
+     triu_indices = np.triu_indices(n_features, k=1)
+     interaction_values = interaction_matrix[triu_indices]
+     abs_values = np.abs(interaction_values)
+
+     # Compute percentiles
+     percentiles = [10, 25, 50, 75, 90, 95, 99]
+     percentile_values = {f"p{p}": float(np.percentile(abs_values, p)) for p in percentiles}
+
+     # Binning for histogram
+     hist, bin_edges = np.histogram(abs_values, bins=20)
+
+     return {
+         "mean": float(np.mean(abs_values)),
+         "median": float(np.median(abs_values)),
+         "std": float(np.std(abs_values)),
+         "min": float(np.min(abs_values)),
+         "max": float(np.max(abs_values)),
+         "percentiles": percentile_values,
+         "histogram": {"counts": hist.tolist(), "bin_edges": bin_edges.tolist()},
+     }
+
+
+ def _detect_interaction_clusters(
+     interaction_matrix: np.ndarray, feature_names: list[str], threshold: float
+ ) -> list[list[str]]:
+     """Detect clusters of strongly interacting features using simple thresholding.
+
+     This is a basic clustering approach based on connected components in the
+     interaction graph. More sophisticated methods could be added later.
+     """
+     n_features = len(feature_names)
+
+     # Create adjacency matrix based on threshold
+     adj_matrix = np.abs(interaction_matrix) > threshold
+     np.fill_diagonal(adj_matrix, False)  # No self-loops
+
+     # Find connected components (simple DFS)
+     visited = [False] * n_features
+     clusters: list[list[str]] = []
+
+     def dfs(node: int, cluster: list[int]) -> None:
+         visited[node] = True
+         cluster.append(node)
+         for neighbor in range(n_features):
+             if adj_matrix[node, neighbor] and not visited[neighbor]:
+                 dfs(neighbor, cluster)
+
+     for i in range(n_features):
+         if not visited[i]:
+             cluster_indices: list[int] = []
+             dfs(i, cluster_indices)
+             if len(cluster_indices) > 1:  # Only include clusters with >1 feature
+                 clusters.append([feature_names[idx] for idx in cluster_indices])
+
+     return clusters
+
+
+ def _generate_interaction_interpretation(
+     feature_name: str, top_partners: list[tuple[str, float]]
+ ) -> str:
+     """Generate auto-interpretation for a single feature's interactions."""
+     if not top_partners:
+         return f"'{feature_name}' has no significant interactions."
+
+     # Get top 3 for narrative
+     top_3 = top_partners[:3]
+     partner_str = ", ".join(f"'{p[0]}' ({p[1]:.3f})" for p in top_3)
+
+     return (
+         f"'{feature_name}' shows strongest interactions with {partner_str}. "
+         "These interaction effects suggest the feature's predictive power "
+         "depends on the values of these partner features."
+     )
+
+
+ def _generate_interaction_llm_context(
+     summary: dict[str, Any],
+     _per_feature: dict[str, FeatureInteractionData],
+     network_graph: NetworkGraphData,
+     strength_distribution: dict[str, Any],
+ ) -> LLMContextData:
+     """Generate auto-narratives for interaction analysis."""
+     n_features = summary["n_features"]
+     n_interactions = summary["n_interactions"]
+     strongest_pair = summary["strongest_pair"]
+     strongest_value = summary["strongest_interaction"]
+     most_interactive = summary["most_interactive_feature"]
+
+     # Build summary narrative
+     summary_narrative = (
+         f"This interaction analysis examined {n_features} features, identifying "
+         f"{n_interactions} pairwise interactions. "
+     )
+
+     summary_narrative += (
+         f"The strongest interaction ({strongest_value:.3f}) occurs between "
+         f"'{strongest_pair[0]}' and '{strongest_pair[1]}'. "
+     )
+
+     if network_graph["clusters"]:
+         n_clusters = len(network_graph["clusters"])
+         summary_narrative += (
+             f"Cluster analysis identified {n_clusters} group(s) of strongly interacting features. "
+         )
+
+     # Key insights
+     key_insights = []
+
+     # Insight 1: Strongest interaction
+     key_insights.append(
+         f"Strongest interaction: {strongest_pair[0]} <-> {strongest_pair[1]} (strength: {strongest_value:.3f})"
+     )
+
+     # Insight 2: Most interactive feature
+     key_insights.append(
+         f"Most interactive feature: '{most_interactive}' (total interaction: {summary['max_total_interaction']:.3f})"
+     )
+
+     # Insight 3: Distribution characteristics
+     mean_strength = strength_distribution["mean"]
+     median_strength = strength_distribution["median"]
+     if mean_strength > median_strength * 1.5:
+         key_insights.append(
+             "Interaction strength distribution is right-skewed "
+             f"(mean: {mean_strength:.3f}, median: {median_strength:.3f}) - "
+             "a few strong interactions dominate"
+         )
+
+     # Insight 4: Clustering
+     if network_graph["clusters"]:
+         largest_cluster = list(max(network_graph["clusters"], key=len))  # type: ignore[arg-type]
+         key_insights.append(
+             f"Largest interaction cluster has {len(largest_cluster)} features: "
+             f"{', '.join(largest_cluster[:5])}" + ("..." if len(largest_cluster) > 5 else "")
+         )
+
+     # Recommendations
+     recommendations = []
+
+     # Rec 1: Focus on strong interactions
+     recommendations.append(
+         f"Investigate the {strongest_pair[0]}/{strongest_pair[1]} interaction further. "
+         "Strong interactions suggest conditional effects or non-linear relationships."
+     )
+
+     # Rec 2: Feature engineering
+     if network_graph["clusters"]:
+         recommendations.append(
+             "Consider creating interaction features (products, ratios) for clustered "
+             "feature groups to capture non-linear effects explicitly."
+         )
+
+     # Rec 3: Model selection
+     recommendations.append(
+         "Tree-based models and neural networks can capture these interactions naturally. "
+         "Linear models may benefit from explicit interaction terms."
+     )
+
+     # Caveats
+     caveats = [
+         "SHAP interaction values measure interactions between feature contributions, not "
+         "statistical correlations. High interaction doesn't imply high correlation.",
+         "Interaction values are model-specific and depend on the underlying model structure.",
+     ]
+
+     # Determine quality
+     if n_features >= 5 and summary["max_total_interaction"] > 0.1:
+         analysis_quality = "high"
+     elif n_features >= 3:
+         analysis_quality = "medium"
+     else:
+         analysis_quality = "low"
+
+     return LLMContextData(
+         summary_narrative=summary_narrative,
+         key_insights=key_insights,
+         recommendations=recommendations,
+         caveats=caveats,
+         analysis_quality=analysis_quality,
+     )
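The extractor only needs a mapping with 'interaction_matrix' and 'feature_names'; everything else is derived. A minimal sketch, using the import path the module's own docstring advertises and otherwise only synthetic data (feature names and parameter values below are made up):

# Illustrative sketch: run a hand-built symmetric matrix through the extractor,
# with no model or SHAP dependency.
import numpy as np

from ml4t.diagnostic.visualization.data_extraction import extract_interaction_viz_data

rng = np.random.default_rng(7)
features = ["momentum", "value", "carry", "vol"]

# Symmetric matrix with zero diagonal, shaped like a SHAP interaction summary.
raw = rng.normal(scale=0.2, size=(4, 4))
matrix = (raw + raw.T) / 2
np.fill_diagonal(matrix, 0.0)

viz = extract_interaction_viz_data(
    {"interaction_matrix": matrix, "feature_names": features},
    n_top_partners=2,
    cluster_threshold=0.15,
)
print(viz["summary"]["strongest_pair"])            # some ('a', 'b') pair
print(viz["per_feature"]["momentum"]["top_partners"])
print(len(viz["network_graph"]["edges"]))          # up to 6 edges for 4 features

Because extraction is pure NumPy plus dict-shaping, it is cheap to exercise against small matrices like this.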
ml4t/diagnostic/visualization/data_extraction/types.py
@@ -0,0 +1,113 @@
+ """Type definitions for data extraction.
+
+ TypedDict classes for structured visualization data packages.
+ """
+
+ from __future__ import annotations
+
+ from typing import Any, TypedDict
+
+
+ class MethodImportanceData(TypedDict, total=False):
+     """Importance data for a single method."""
+
+     importances: dict[str, float]  # feature_name -> importance_score
+     ranking: list[str]  # Features sorted by importance
+     std: dict[str, float] | None  # Standard deviation if available (PFI)
+     confidence_intervals: dict[str, tuple[float, float]] | None  # 95% CI if available
+     raw_values: list[dict[str, float]] | None  # Per-repeat values (PFI)
+     metadata: dict[str, Any]  # Method-specific metadata
+
+
+ class FeatureDetailData(TypedDict):
+     """Complete data for a single feature across all analyses."""
+
+     consensus_rank: int  # Overall ranking
+     consensus_score: float  # Consensus importance score
+     method_ranks: dict[str, int]  # Method name -> rank in that method
+     method_scores: dict[str, float]  # Method name -> importance score
+     method_stds: dict[str, float]  # Method name -> std dev (if available)
+     agreement_level: str  # 'high', 'medium', 'low'
+     stability_score: float  # 0-1, higher = more stable
+     interpretation: str  # Auto-generated interpretation
+
+
+ class MethodComparisonData(TypedDict):
+     """Method agreement and comparison metrics."""
+
+     correlation_matrix: list[list[float]]  # Method x Method correlation matrix
+     correlation_methods: list[str]  # Method names for matrix axes
+     rank_differences: dict[
+         tuple[str, str], dict[str, int]
+     ]  # (method1, method2) -> {feature: rank_diff}
+     agreement_summary: dict[str, float]  # Pairwise correlations as dict
+
+
+ class UncertaintyData(TypedDict):
+     """Uncertainty and stability metrics."""
+
+     method_stability: dict[str, float]  # Method -> stability score (0-1)
+     rank_stability: dict[str, list[int]]  # Feature -> list of ranks across bootstraps
+     confidence_intervals: dict[str, dict[str, tuple[float, float]]]  # Method -> {feature: (lo, hi)}
+     coefficient_of_variation: dict[str, dict[str, float]]  # Method -> {feature: CV}
+
+
+ class LLMContextData(TypedDict):
+     """Structured data for LLM interpretation."""
+
+     summary_narrative: str  # High-level summary in natural language
+     key_insights: list[str]  # Bullet points of findings
+     recommendations: list[str]  # Actionable recommendations
+     caveats: list[str]  # Limitations and warnings
+     analysis_quality: str  # 'high', 'medium', 'low'
+
+
+ class ImportanceVizData(TypedDict):
+     """Complete visualization data package for importance analysis."""
+
+     summary: dict[str, Any]  # High-level metrics
+     per_method: dict[str, MethodImportanceData]  # Method name -> detailed data
+     per_feature: dict[str, FeatureDetailData]  # Feature name -> aggregated view
+     uncertainty: UncertaintyData  # Stability and confidence metrics
+     method_comparison: MethodComparisonData  # Cross-method analysis
+     metadata: dict[str, Any]  # Context information
+     llm_context: LLMContextData  # LLM-friendly narratives
+
+
+ class FeatureInteractionData(TypedDict):
+     """Interaction data for a single feature."""
+
+     feature_name: str
+     top_partners: list[tuple[str, float]]  # (partner_feature, interaction_strength)
+     total_interaction_strength: float  # Sum of absolute interactions
+     cluster_id: int | None  # ID of interaction cluster (if clustering performed)
+     interpretation: str  # Auto-generated interpretation
+
+
+ class NetworkGraphData(TypedDict):
+     """Network graph representation of interactions."""
+
+     nodes: list[dict[str, Any]]  # [{id: str, label: str, importance: float, ...}]
+     edges: list[dict[str, Any]]  # [{source: str, target: str, weight: float, ...}]
+     clusters: list[list[str]]  # List of feature clusters based on interactions
+
+
+ class InteractionMatrixData(TypedDict):
+     """Matrix representation of pairwise interactions."""
+
+     features: list[str]  # Ordered feature names
+     matrix: list[list[float]]  # Symmetric interaction matrix
+     max_interaction: float  # Maximum interaction value
+     mean_interaction: float  # Mean interaction strength
+
+
+ class InteractionVizData(TypedDict):
+     """Complete visualization data package for interaction analysis."""
+
+     summary: dict[str, Any]  # High-level metrics
+     per_feature: dict[str, FeatureInteractionData]  # Feature -> interaction details
+     network_graph: NetworkGraphData  # Graph visualization data
+     interaction_matrix: InteractionMatrixData  # Matrix visualization data
+     strength_distribution: dict[str, Any]  # Distribution of interaction strengths
+     metadata: dict[str, Any]  # Context information
+     llm_context: LLMContextData  # LLM-friendly narratives
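These TypedDicts impose no runtime cost: payloads are plain dicts, and the shapes are enforced statically (e.g. by mypy). A hypothetical literal conforming to FeatureInteractionData, with the import path assumed from the file layout above:

# Illustrative only; the values are made up.
from ml4t.diagnostic.visualization.data_extraction.types import FeatureInteractionData

row: FeatureInteractionData = {
    "feature_name": "momentum",
    "top_partners": [("vol", 0.31), ("value", -0.12)],  # (partner, strength)
    "total_interaction_strength": 0.43,
    "cluster_id": None,  # no cluster assigned
    "interpretation": "'momentum' shows strongest interactions with 'vol' (0.310), 'value' (-0.120).",
}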
ml4t/diagnostic/visualization/data_extraction/validation.py
@@ -0,0 +1,66 @@
+ """Validation helpers for data extraction.
+
+ Provides length and dimension validation for extracted visualization data.
+ """
+
+ from __future__ import annotations
+
+ import numpy as np
+
+
+ def _validate_lengths_match(
+     *arrays: tuple[str, list | np.ndarray],
+ ) -> None:
+     """Validate that all provided arrays have matching lengths.
+
+     Parameters
+     ----------
+     *arrays : tuple[str, list | np.ndarray]
+         Tuples of (name, array) to validate.
+
+     Raises
+     ------
+     ValueError
+         If arrays have different lengths.
+     """
+     if not arrays:
+         return
+
+     lengths = [(name, len(arr)) for name, arr in arrays]
+     unique_lengths = {length for _, length in lengths}
+
+     if len(unique_lengths) > 1:
+         length_info = ", ".join(f"{name}={length}" for name, length in lengths)
+         raise ValueError(
+             f"Length mismatch in data extraction: {length_info}. "
+             "All arrays must have the same length for consistent visualization."
+         )
+
+
+ def _validate_matrix_feature_alignment(matrix: np.ndarray, feature_names: list[str]) -> None:
+     """Validate that interaction matrix dimensions match feature names.
+
+     Parameters
+     ----------
+     matrix : np.ndarray
+         Square interaction matrix.
+     feature_names : list[str]
+         Feature names for matrix axes.
+
+     Raises
+     ------
+     ValueError
+         If matrix is not square or dimensions don't match feature count.
+     """
+     n_features = len(feature_names)
+     if matrix.ndim != 2:
+         raise ValueError(
+             f"Interaction matrix must be 2D, got {matrix.ndim}D with shape {matrix.shape}"
+         )
+     if matrix.shape[0] != matrix.shape[1]:
+         raise ValueError(f"Interaction matrix must be square, got shape {matrix.shape}")
+     if matrix.shape[0] != n_features:
+         raise ValueError(
+             f"Interaction matrix size ({matrix.shape[0]}) does not match "
+             f"number of features ({n_features})"
+         )
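A short sketch of how the two validators behave; both are private helpers, shown here only to document the failure modes, with the module path assumed from the listing above:

import numpy as np

from ml4t.diagnostic.visualization.data_extraction.validation import (
    _validate_lengths_match,
    _validate_matrix_feature_alignment,
)

# Matching lengths pass silently.
_validate_lengths_match(("features", ["a", "b"]), ("scores", [0.1, 0.2]))

# A non-square matrix is rejected before any feature-count comparison.
try:
    _validate_matrix_feature_alignment(np.zeros((3, 2)), ["a", "b", "c"])
except ValueError as err:
    print(err)  # Interaction matrix must be square, got shape (3, 2)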