ml4t_diagnostic-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/signal_selector.py
@@ -0,0 +1,452 @@
"""Signal selection algorithms for multi-signal comparison.

This module provides intelligent signal selection algorithms to identify
the most promising signals from a large set based on various criteria:

- **Top-N**: Select best signals by a single metric
- **Uncorrelated**: Select diverse signals with low correlation
- **Pareto Frontier**: Select non-dominated signals on two metrics
- **Cluster Representatives**: Select best signal from each correlation cluster

These algorithms help reduce a large signal universe (50-200) to a manageable
subset for detailed comparison while maximizing information value.

Examples
--------
>>> from ml4t.diagnostic.evaluation.signal_selector import SignalSelector
>>>
>>> # Select top 10 by IC IR
>>> top_signals = SignalSelector.select_top_n(summary, n=10, metric="ic_ir")
>>>
>>> # Select 5 uncorrelated signals
>>> diverse = SignalSelector.select_uncorrelated(
...     summary, correlation_matrix, n=5, max_correlation=0.5
... )
>>>
>>> # Find Pareto-optimal signals (low turnover, high IC)
>>> efficient = SignalSelector.select_pareto_frontier(
...     summary, x_metric="turnover_mean", y_metric="ic_ir"
... )
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
import polars as pl

if TYPE_CHECKING:
    pass


class SignalSelector:
    """Smart signal selection algorithms for comparison.

    Provides static methods for selecting subsets of signals based on
    different criteria. All methods are designed to work with summary
    DataFrames from MultiSignalAnalysis.

    Methods
    -------
    select_top_n : Select top N signals by metric
    select_uncorrelated : Select diverse, uncorrelated signals
    select_pareto_frontier : Select Pareto-optimal signals
    select_by_cluster : Select representative from each cluster
    """

    @staticmethod
    def select_top_n(
        summary_df: pl.DataFrame,
        n: int = 10,
        metric: str = "ic_ir",
        ascending: bool = False,
        filter_significant: bool = False,
        significance_col: str = "fdr_significant",
    ) -> list[str]:
        """Select top N signals by a single metric.

        Parameters
        ----------
        summary_df : pl.DataFrame
            Summary DataFrame with columns: signal_name, {metric}
        n : int, default 10
            Number of signals to select
        metric : str, default "ic_ir"
            Metric column to sort by
        ascending : bool, default False
            If True, select lowest values (e.g., for turnover)
        filter_significant : bool, default False
            If True, only consider signals that pass the significance threshold
        significance_col : str, default "fdr_significant"
            Column containing the significance flag

        Returns
        -------
        list[str]
            Signal names of the top N signals

        Examples
        --------
        >>> # Top 10 by IC IR (highest)
        >>> top = SignalSelector.select_top_n(summary, n=10, metric="ic_ir")
        >>>
        >>> # Top 10 lowest turnover
        >>> low_turn = SignalSelector.select_top_n(
        ...     summary, n=10, metric="turnover_mean", ascending=True
        ... )
        """
        if metric not in summary_df.columns:
            raise ValueError(f"Metric '{metric}' not found. Available: {summary_df.columns}")

        df = summary_df

        # Optionally filter to significant only
        if filter_significant and significance_col in df.columns:
            df = df.filter(pl.col(significance_col))

        # Sort and take top N
        sorted_df = df.sort(metric, descending=not ascending)
        return sorted_df.head(n)["signal_name"].to_list()

    @staticmethod
    def select_uncorrelated(
        summary_df: pl.DataFrame,
        correlation_matrix: pl.DataFrame,
        n: int = 5,
        metric: str = "ic_ir",
        min_metric_value: float | None = None,
        max_correlation: float = 0.7,
    ) -> list[str]:
        """Select top N signals that are least correlated with each other.

        Uses a greedy algorithm:

        1. Filter signals with metric >= min_metric_value (if specified)
        2. Sort remaining signals by metric (descending)
        3. Select the best signal
        4. Walk down the ranking and keep each signal whose absolute
           correlation with every already-selected signal is at most
           max_correlation; skip signals that exceed the threshold
        5. Stop when N signals are selected or no candidates remain

        Parameters
        ----------
        summary_df : pl.DataFrame
            Summary DataFrame with signal_name and metric columns
        correlation_matrix : pl.DataFrame
            Square correlation matrix with signal names as both index and columns
        n : int, default 5
            Number of signals to select
        metric : str, default "ic_ir"
            Metric to rank signals by (higher is better)
        min_metric_value : float | None, default None
            Minimum metric value to consider a signal
        max_correlation : float, default 0.7
            Maximum allowed correlation between selected signals

        Returns
        -------
        list[str]
            Signal names of selected uncorrelated signals

        Notes
        -----
        This is a greedy algorithm that may not find the globally optimal
        subset, but it works well in practice and is O(n²) in the number of
        signals.

        Examples
        --------
        >>> # Select 5 diverse signals with IC > 0.02
        >>> diverse = SignalSelector.select_uncorrelated(
        ...     summary, corr_matrix, n=5,
        ...     min_metric_value=0.02, max_correlation=0.5
        ... )
        """
        # Get available signals and their metrics
        candidates = summary_df.select(["signal_name", metric])

        # Filter by minimum metric if specified
        if min_metric_value is not None:
            candidates = candidates.filter(pl.col(metric) >= min_metric_value)

        if len(candidates) == 0:
            return []

        # Sort by metric descending
        candidates = candidates.sort(metric, descending=True)
        candidate_names = candidates["signal_name"].to_list()

        # Convert correlation matrix to numpy for efficient indexing
        corr_signals = correlation_matrix.columns
        corr_numpy = correlation_matrix.to_numpy()

        # Build name-to-index mapping
        signal_to_idx = {name: i for i, name in enumerate(corr_signals)}

        # Greedy selection
        selected: list[str] = []
        remaining = set(candidate_names)

        for signal_name in candidate_names:
            if signal_name not in remaining:
                continue

            if signal_name not in signal_to_idx:
                # Signal not in correlation matrix (shouldn't happen normally)
                remaining.discard(signal_name)
                continue

            # Check correlation with already selected signals
            if len(selected) > 0:
                idx = signal_to_idx[signal_name]
                selected_idxs = [signal_to_idx[s] for s in selected]
                correlations = np.abs(corr_numpy[idx, selected_idxs])
                max_corr = np.max(correlations)

                if max_corr > max_correlation:
                    remaining.discard(signal_name)
                    continue

            # Select this signal
            selected.append(signal_name)
            remaining.discard(signal_name)

            if len(selected) >= n:
                break

        return selected

    @staticmethod
    def select_pareto_frontier(
        summary_df: pl.DataFrame,
        x_metric: str = "turnover_mean",
        y_metric: str = "ic_ir",
        minimize_x: bool = True,
        maximize_y: bool = True,
    ) -> list[str]:
        """Select signals on the Pareto frontier (efficient frontier).

        A signal is Pareto-optimal if no other signal is strictly better
        on both metrics. This finds signals that represent different
        trade-offs between the two metrics.

        Parameters
        ----------
        summary_df : pl.DataFrame
            Summary DataFrame with signal_name, x_metric, y_metric columns
        x_metric : str, default "turnover_mean"
            First metric (typically to minimize, like turnover)
        y_metric : str, default "ic_ir"
            Second metric (typically to maximize, like IC)
        minimize_x : bool, default True
            If True, lower x values are better
        maximize_y : bool, default True
            If True, higher y values are better

        Returns
        -------
        list[str]
            Signal names on the Pareto frontier, sorted by x_metric

        Notes
        -----
        The Pareto frontier helps identify signals that represent different
        trade-offs. For example, one signal might have the highest IC but
        also the highest turnover, while another has moderate IC with low
        turnover. Both are Pareto-optimal.

        Time complexity: O(n²) where n is the number of signals.

        Examples
        --------
        >>> # Find signals with the best IC vs turnover trade-off
        >>> frontier = SignalSelector.select_pareto_frontier(
        ...     summary, x_metric="turnover_mean", y_metric="ic_ir"
        ... )
        >>> print(f"{len(frontier)} Pareto-optimal signals")
        """
        if x_metric not in summary_df.columns or y_metric not in summary_df.columns:
            raise ValueError(
                f"Metrics not found. Required: {x_metric}, {y_metric}. "
                f"Available: {summary_df.columns}"
            )

        # Extract data
        data = summary_df.select(["signal_name", x_metric, y_metric]).to_numpy()
        names = data[:, 0].tolist()
        x_values = data[:, 1].astype(float)
        y_values = data[:, 2].astype(float)

        # Convert to "higher is better" for comparison
        if minimize_x:
            x_values = -x_values
        if not maximize_y:
            y_values = -y_values

        # Find Pareto frontier
        n = len(names)
        pareto_mask = np.ones(n, dtype=bool)

        for i in range(n):
            if not pareto_mask[i]:
                continue
            for j in range(n):
                if i == j or not pareto_mask[j]:
                    continue
                # j dominates i if it is at least as good on both metrics
                # and strictly better on at least one
                if x_values[j] >= x_values[i] and y_values[j] >= y_values[i]:
                    if x_values[j] > x_values[i] or y_values[j] > y_values[i]:
                        pareto_mask[i] = False
                        break

        # Sort by original x_metric (not negated)
        x_original = data[:, 1].astype(float)
        pareto_with_x = [(names[i], x_original[i]) for i in range(n) if pareto_mask[i]]
        pareto_with_x.sort(key=lambda x: x[1], reverse=not minimize_x)

        return [name for name, _ in pareto_with_x]

    @staticmethod
    def select_by_cluster(
        correlation_matrix: pl.DataFrame,
        summary_df: pl.DataFrame,
        n_clusters: int = 5,
        signals_per_cluster: int = 1,
        metric: str = "ic_ir",
        linkage_method: str = "ward",
    ) -> list[str]:
        """Select representative signals from each correlation cluster.

        Uses hierarchical clustering on correlation distance to group
        similar signals, then selects the best signal(s) from each cluster.

        Parameters
        ----------
        correlation_matrix : pl.DataFrame
            Square correlation matrix (signals as columns)
        summary_df : pl.DataFrame
            Summary with signal_name and metric columns
        n_clusters : int, default 5
            Number of clusters to create
        signals_per_cluster : int, default 1
            Number of signals to select from each cluster
        metric : str, default "ic_ir"
            Metric for selecting the best signal within a cluster
        linkage_method : str, default "ward"
            Hierarchical clustering linkage method

        Returns
        -------
        list[str]
            Selected signal names (one per cluster, sorted by metric)

        Notes
        -----
        This method is useful for finding truly independent signal sources.
        Clustering can reveal the "100 signals = 3 unique bets" pattern.

        Requires scipy for hierarchical clustering.

        Examples
        --------
        >>> # Select the best signal from each of 5 clusters
        >>> reps = SignalSelector.select_by_cluster(
        ...     corr_matrix, summary, n_clusters=5
        ... )
        """
        try:
            from scipy.cluster.hierarchy import cut_tree, linkage
        except ImportError as err:
            raise ImportError(
                "scipy required for cluster selection. Install with: pip install scipy"
            ) from err

        # Get signal names and correlation matrix
        signal_names = correlation_matrix.columns
        corr_np = correlation_matrix.to_numpy()

        # Convert correlation to distance (1 - |correlation|)
        distance = 1 - np.abs(corr_np)
        np.fill_diagonal(distance, 0)

        # Perform hierarchical clustering
        # linkage expects a condensed distance matrix
        n = len(signal_names)
        condensed = distance[np.triu_indices(n, k=1)]
        linkage_matrix = linkage(condensed, method=linkage_method)

        # Cut tree to get cluster labels
        cluster_labels = cut_tree(linkage_matrix, n_clusters=n_clusters).flatten()

        # Build cluster -> signals mapping
        clusters: dict[int, list[str]] = {i: [] for i in range(n_clusters)}
        for i, signal in enumerate(signal_names):
            clusters[cluster_labels[i]].append(signal)

        # Get metric values from summary
        metric_lookup = dict(
            zip(
                summary_df["signal_name"].to_list(),
                summary_df[metric].to_list(),
            )
        )

        # Select best signal(s) from each cluster
        selected: list[str] = []
        for cluster_id in range(n_clusters):
            cluster_signals = clusters[cluster_id]
            if not cluster_signals:
                continue

            # Sort by metric and take top signals_per_cluster
            sorted_signals = sorted(
                cluster_signals,
                key=lambda s: metric_lookup.get(s, float("-inf")),
                reverse=True,
            )
            selected.extend(sorted_signals[:signals_per_cluster])

        # Sort final list by metric
        selected.sort(
            key=lambda s: metric_lookup.get(s, float("-inf")),
            reverse=True,
        )

        return selected

    @staticmethod
    def get_selection_info(
        summary_df: pl.DataFrame,
        selected_signals: list[str],
        method: str,
        **method_params: Any,
    ) -> dict[str, Any]:
        """Get information about a signal selection for documentation.

        Parameters
        ----------
        summary_df : pl.DataFrame
            Summary DataFrame
        selected_signals : list[str]
            List of selected signal names
        method : str
            Selection method name ("top_n", "uncorrelated", "pareto", "cluster")
        **method_params : Any
            Parameters used for selection

        Returns
        -------
        dict
            Dictionary with selection metadata for reporting
        """
        # Get metrics for selected signals
        selected_data = summary_df.filter(pl.col("signal_name").is_in(selected_signals))

        return {
            "method": method,
            "n_selected": len(selected_signals),
            "n_total": len(summary_df),
            "signals": selected_signals,
            "method_params": method_params,
            "selected_summary": selected_data.to_dicts(),
        }
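For orientation, a minimal self-contained sketch of how the selectors above can be driven. The summary table and correlation matrix below are synthetic stand-ins for MultiSignalAnalysis output; only the column names (signal_name, ic_ir, turnover_mean) and the SignalSelector API come from the file itself.

import numpy as np
import polars as pl

from ml4t.diagnostic.evaluation.signal_selector import SignalSelector

rng = np.random.default_rng(0)
names = [f"sig_{i}" for i in range(8)]

# Hypothetical summary table in the shape the docstrings describe
summary = pl.DataFrame(
    {
        "signal_name": names,
        "ic_ir": rng.normal(0.5, 0.3, size=8),
        "turnover_mean": rng.uniform(0.05, 0.6, size=8),
    }
)

# Square correlation matrix with one numeric column per signal, same order as `names`
corr = pl.DataFrame(np.corrcoef(rng.normal(size=(250, 8)), rowvar=False), schema=names)

top3 = SignalSelector.select_top_n(summary, n=3, metric="ic_ir")
diverse = SignalSelector.select_uncorrelated(summary, corr, n=3, max_correlation=0.5)
frontier = SignalSelector.select_pareto_frontier(
    summary, x_metric="turnover_mean", y_metric="ic_ir"
)
print(top3, diverse, frontier, sep="\n")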
ml4t/diagnostic/evaluation/stat_registry.py
@@ -0,0 +1,139 @@
"""Statistical test registry for evaluation framework.

This module provides a centralized registry for statistical tests
used in the evaluation framework, including tier defaults.
"""

from collections.abc import Callable
from typing import Any


class StatTestRegistry:
    """Registry of statistical tests for evaluation.

    The StatTestRegistry provides a centralized place to register and query
    statistical tests, including their tier defaults.

    Attributes
    ----------
    _tests : dict[str, Callable]
        Mapping of test names to test functions
    _tier_defaults : dict[int, list[str]]
        Default tests for each evaluation tier

    Examples
    --------
    >>> registry = StatTestRegistry()
    >>> registry.register("dsr", dsr_func, tiers=[1])
    >>> func = registry.get("dsr")
    """

    _instance: "StatTestRegistry | None" = None

    def __init__(self) -> None:
        """Initialize empty registry."""
        self._tests: dict[str, Callable[..., Any]] = {}
        self._tier_defaults: dict[int, list[str]] = {1: [], 2: [], 3: []}

    @classmethod
    def default(cls) -> "StatTestRegistry":
        """Get or create the default singleton registry instance.

        Returns
        -------
        StatTestRegistry
            The default registry instance with standard tests registered
        """
        if cls._instance is None:
            cls._instance = cls()
            cls._instance._register_defaults()
        return cls._instance

    @classmethod
    def reset_default(cls) -> None:
        """Reset the default singleton instance (primarily for testing)."""
        cls._instance = None

    def register(
        self,
        name: str,
        func: Callable[..., Any],
        tiers: list[int] | None = None,
    ) -> None:
        """Register a statistical test with the registry.

        Parameters
        ----------
        name : str
            Unique name for the test
        func : Callable
            Function that performs the test.
            Should return a dict with test results
        tiers : list[int], optional
            Evaluation tiers where this test is a default
        """
        self._tests[name] = func
        if tiers:
            for tier in tiers:
                if tier in self._tier_defaults and name not in self._tier_defaults[tier]:
                    self._tier_defaults[tier].append(name)

    def get(self, name: str) -> Callable[..., Any]:
        """Get a test function by name.

        Parameters
        ----------
        name : str
            Name of the test

        Returns
        -------
        Callable
            The test function

        Raises
        ------
        KeyError
            If test name is not registered
        """
        if name not in self._tests:
            raise KeyError(f"Unknown test: {name}. Available: {list(self._tests.keys())}")
        return self._tests[name]

    def get_by_tier(self, tier: int) -> list[str]:
        """Get default tests for a specific tier.

        Parameters
        ----------
        tier : int
            Evaluation tier (1, 2, or 3)

        Returns
        -------
        list[str]
            List of default test names for the tier
        """
        return self._tier_defaults.get(tier, []).copy()

    def list_tests(self) -> list[str]:
        """List all registered test names.

        Returns
        -------
        list[str]
            Sorted list of test names
        """
        return sorted(self._tests.keys())

    def __contains__(self, name: str) -> bool:
        """Check if a test is registered."""
        return name in self._tests

    def _register_defaults(self) -> None:
        """Register default statistical tests."""
        from . import stats

        self.register("dsr", stats.deflated_sharpe_ratio_from_statistics, tiers=[1])
        self.register("hac_ic", stats.robust_ic, tiers=[2])
        self.register("fdr", stats.benjamini_hochberg_fdr, tiers=[1])
        self.register("whites_reality_check", stats.whites_reality_check, tiers=[])
ml4t/diagnostic/evaluation/stationarity/__init__.py
@@ -0,0 +1,97 @@
"""Stationarity testing for time series features.

This module provides statistical tests for detecting unit roots and assessing
stationarity of financial time series:

- Augmented Dickey-Fuller (ADF) test - tests for unit root (H0: non-stationary)
- KPSS test - tests for stationarity (H0: stationary)
- Phillips-Perron (PP) test - robust alternative to ADF (H0: non-stationary)

Stationarity is a critical assumption for many time series models and
feature engineering techniques. Non-stationary series require transformation
(differencing, detrending) before use in predictive models.

Key Differences Between Tests:
    - ADF: Parametric test with lagged differences, H0 = unit root (non-stationary)
    - PP: Non-parametric correction for serial correlation, H0 = unit root (non-stationary)
    - KPSS: H0 = stationarity (opposite interpretation!)
    - Use multiple tests together for robust stationarity assessment:
        - Stationary: ADF/PP rejects + KPSS fails to reject
        - Non-stationary: ADF/PP fails to reject + KPSS rejects
        - Quasi-stationary: both reject or both fail to reject (inconclusive)

Phillips-Perron vs ADF:
    - PP uses non-parametric Newey-West correction for heteroscedasticity
    - PP estimates regression with only 1 lag vs ADF's multiple lags
    - PP more robust to general forms of serial correlation
    - Both have same null hypothesis: unit root exists (non-stationary)

References:
    - Dickey, D. A., & Fuller, W. A. (1979). Distribution of the estimators
      for autoregressive time series with a unit root.
    - Phillips, P. C., & Perron, P. (1988). Testing for a unit root in time
      series regression. Biometrika, 75(2), 335-346.
    - MacKinnon, J. G. (1994). Approximate asymptotic distribution functions
      for unit-root and cointegration tests.
    - Kwiatkowski, D., Phillips, P. C., Schmidt, P., & Shin, Y. (1992).
      Testing the null hypothesis of stationarity against the alternative
      of a unit root. Journal of Econometrics, 54(1-3), 159-178.

Example:
    >>> import numpy as np
    >>> from ml4t.diagnostic.evaluation.stationarity import adf_test, kpss_test
    >>>
    >>> # White noise (stationary)
    >>> white_noise = np.random.randn(1000)
    >>> adf = adf_test(white_noise)
    >>> kpss = kpss_test(white_noise)
    >>> print(f"ADF stationary: {adf.is_stationary}")   # Should be True
    >>> print(f"KPSS stationary: {kpss.is_stationary}")  # Should be True
    >>>
    >>> # Random walk (non-stationary)
    >>> random_walk = np.cumsum(np.random.randn(1000))
    >>> adf = adf_test(random_walk)
    >>> kpss = kpss_test(random_walk)
    >>> print(f"ADF stationary: {adf.is_stationary}")   # Should be False
    >>> print(f"KPSS stationary: {kpss.is_stationary}")  # Should be False
    >>>
    >>> # Comprehensive analysis with all tests
    >>> from ml4t.diagnostic.evaluation.stationarity import analyze_stationarity
    >>> result = analyze_stationarity(random_walk)
    >>> print(result.summary())
"""

# Import from submodules and re-export
from ml4t.diagnostic.evaluation.stationarity.analysis import (
    StationarityAnalysisResult,
    analyze_stationarity,
)
from ml4t.diagnostic.evaluation.stationarity.augmented_dickey_fuller import (
    ADFResult,
    adf_test,
)
from ml4t.diagnostic.evaluation.stationarity.kpss_test import (
    KPSSResult,
    kpss_test,
)
from ml4t.diagnostic.evaluation.stationarity.phillips_perron import (
    HAS_ARCH,
    PPResult,
    pp_test,
)

__all__ = [
    # ADF test
    "adf_test",
    "ADFResult",
    # KPSS test
    "kpss_test",
    "KPSSResult",
    # PP test
    "pp_test",
    "PPResult",
    "HAS_ARCH",
    # Comprehensive analysis
    "analyze_stationarity",
    "StationarityAnalysisResult",
]
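The ADF/KPSS decision matrix spelled out in the docstring above can be sketched as a small standalone function. This version calls statsmodels' adfuller and kpss directly rather than the package's adf_test/kpss_test wrappers; the 5% level and the labels are illustrative choices.

import numpy as np
from statsmodels.tsa.stattools import adfuller, kpss


def classify_stationarity(series: np.ndarray, alpha: float = 0.05) -> str:
    # ADF: H0 = unit root (non-stationary); KPSS: H0 = stationary.
    adf_p = adfuller(series, autolag="AIC")[1]
    # kpss() warns when the p-value falls outside its 0.01-0.10 lookup table; that is expected here.
    kpss_p = kpss(series, regression="c", nlags="auto")[1]

    adf_rejects = adf_p < alpha     # evidence against a unit root
    kpss_rejects = kpss_p < alpha   # evidence against stationarity

    if adf_rejects and not kpss_rejects:
        return "stationary"
    if not adf_rejects and kpss_rejects:
        return "non-stationary"
    return "inconclusive"           # both reject or both fail to reject


rng = np.random.default_rng(42)
print(classify_stationarity(rng.normal(size=1000)))              # expected: stationary
print(classify_stationarity(np.cumsum(rng.normal(size=1000))))   # expected: non-stationary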