ml4t-diagnostic 0.1.0a1__py3-none-any.whl

Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/splitters/cpcv/partitioning.py
@@ -0,0 +1,263 @@
"""Group partitioning strategies for CPCV.

This module handles partitioning the timeline into groups:
- Contiguous partitioning (equal-sized time slices)
- Session-aligned partitioning (respects trading session boundaries)
"""

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import numpy as np
from numpy.typing import NDArray

if TYPE_CHECKING:
    import pandas as pd
    import polars as pl


def create_contiguous_partitions(
    n_samples: int,
    n_groups: int,
) -> list[tuple[int, int]]:
    """Create boundaries for contiguous groups.

    Partitions n_samples into n_groups approximately equal-sized groups.
    Earlier groups get extra samples when n_samples is not evenly divisible.

    Parameters
    ----------
    n_samples : int
        Total number of samples.
    n_groups : int
        Number of groups to create.

    Returns
    -------
    boundaries : list of tuple
        List of (start_idx, end_idx) for each group.
        end_idx is exclusive (standard Python convention).

    Raises
    ------
    ValueError
        If boundaries don't satisfy CPCV invariants.

    Examples
    --------
    >>> create_contiguous_partitions(100, 5)
    [(0, 20), (20, 40), (40, 60), (60, 80), (80, 100)]

    >>> create_contiguous_partitions(103, 5)
    [(0, 21), (21, 42), (42, 63), (63, 83), (83, 103)]
    """
    base_size = n_samples // n_groups
    remainder = n_samples % n_groups

    boundaries = []
    current_start = 0

    for i in range(n_groups):
        # Add extra sample to first 'remainder' groups
        group_size = base_size + (1 if i < remainder else 0)
        group_end = current_start + group_size

        boundaries.append((current_start, group_end))
        current_start = group_end

    # Validate invariants
    validate_contiguous_partitions(boundaries, n_samples)

    return boundaries


def validate_contiguous_partitions(
    boundaries: list[tuple[int, int]],
    n_samples: int,
) -> None:
    """Validate CPCV group boundary invariants.

    Ensures:
    1. All samples are covered (no gaps)
    2. No overlap between groups
    3. Groups are contiguous

    Parameters
    ----------
    boundaries : list of tuple
        List of (start_idx, end_idx) for each group.
    n_samples : int
        Total number of samples.

    Raises
    ------
    ValueError
        If any invariant is violated.
    """
    if not boundaries:
        raise ValueError("CPCV invariant violated: no group boundaries created")

    # Check first boundary starts at 0
    if boundaries[0][0] != 0:
        raise ValueError(
            f"CPCV invariant violated: first group must start at 0, got {boundaries[0][0]}"
        )

    # Check last boundary ends at n_samples
    if boundaries[-1][1] != n_samples:
        raise ValueError(
            f"CPCV invariant violated: last group must end at {n_samples}, got {boundaries[-1][1]}"
        )

    # Check contiguity (each group starts where previous ended)
    for i in range(1, len(boundaries)):
        prev_end = boundaries[i - 1][1]
        curr_start = boundaries[i][0]
        if curr_start != prev_end:
            raise ValueError(
                f"CPCV invariant violated: gap between group {i - 1} (ends at {prev_end}) "
                f"and group {i} (starts at {curr_start})"
            )

    # Check each group is non-empty
    for i, (start, end) in enumerate(boundaries):
        if end <= start:
            raise ValueError(
                f"CPCV invariant violated: group {i} is empty or invalid (start={start}, end={end})"
            )
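
# --- Illustrative sketch (editor's note, not part of the package) ----------
# validate_contiguous_partitions raises on any gap, overlap, or empty group,
# e.g. a two-index gap between groups:
#
#     validate_contiguous_partitions([(0, 10), (12, 20)], 20)
#     # ValueError: CPCV invariant violated: gap between group 0
#     # (ends at 10) and group 1 (starts at 12)
# ----------------------------------------------------------------------------
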
def create_session_partitions(
    X: pl.DataFrame | pd.DataFrame,
    session_col: str,
    n_groups: int,
    session_to_indices_fn: Callable[
        [pl.DataFrame | pd.DataFrame, str],
        tuple[list[Any], dict[Any, NDArray[np.intp]]],
    ],
) -> list[NDArray[np.intp]]:
    """Create exact index arrays per group, aligned to session boundaries.

    Unlike contiguous partitioning, which returns (start, end) ranges,
    this function returns EXACT index arrays for each group. This is critical
    for correct behavior with non-contiguous or interleaved data.

    Parameters
    ----------
    X : DataFrame
        Data with session column.
    session_col : str
        Name of column containing session identifiers.
    n_groups : int
        Number of groups to create.
    session_to_indices_fn : callable
        Function that returns (ordered_sessions, session_to_indices_dict).
        Typically BaseSplitter._session_to_indices.

    Returns
    -------
    group_indices : list of np.ndarray
        List of numpy arrays containing exact row indices for each group.
        Each array contains the indices for all rows belonging to sessions
        in that group.

    Raises
    ------
    ValueError
        If there are not enough sessions for the requested number of groups.

    Notes
    -----
    The key difference from contiguous partitioning is that we track
    exact indices rather than (start, end) boundaries. This prevents
    incorrect index ranges when data is interleaved by asset within sessions.
    """
    # Get session -> indices mapping
    ordered_sessions, session_to_indices = session_to_indices_fn(X, session_col)
    n_sessions = len(ordered_sessions)

    if n_sessions < n_groups:
        raise ValueError(
            f"Not enough sessions ({n_sessions}) for {n_groups} groups. "
            f"Need at least {n_groups} sessions."
        )

    # Partition sessions into groups
    base_sessions_per_group = n_sessions // n_groups
    remainder = n_sessions % n_groups

    group_indices_list = []
    current_session_idx = 0

    for i in range(n_groups):
        # Add extra session to first 'remainder' groups
        sessions_in_group = base_sessions_per_group + (1 if i < remainder else 0)
        session_group_end = current_session_idx + sessions_in_group

        # Get sessions for this group
        group_sessions = ordered_sessions[current_session_idx:session_group_end]

        # Collect EXACT indices for sessions in this group
        indices_arrays = [session_to_indices[s] for s in group_sessions]
        if indices_arrays:
            group_indices = np.concatenate(indices_arrays)
            # Sort for predictable ordering
            group_indices = np.sort(group_indices)
        else:
            group_indices = np.array([], dtype=np.intp)

        group_indices_list.append(group_indices)
        current_session_idx = session_group_end

    return group_indices_list
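
# --- Illustrative sketch (editor's note, not part of the package) ----------
# create_session_partitions needs a session_to_indices_fn shaped like
# BaseSplitter._session_to_indices: the session ids in order plus a mapping
# from session id to row indices. A minimal stand-in, assuming a polars
# DataFrame with a "session" column:
#
#     def toy_session_to_indices(X, session_col):
#         sessions = X[session_col].unique(maintain_order=True).to_list()
#         mapping = {
#             s: np.flatnonzero((X[session_col] == s).to_numpy()).astype(np.intp)
#             for s in sessions
#         }
#         return sessions, mapping
#
#     X = pl.DataFrame({"session": ["a", "a", "b", "b", "c", "c"]})
#     create_session_partitions(X, "session", 3, toy_session_to_indices)
#     # -> [array([0, 1]), array([2, 3]), array([4, 5])]
# ----------------------------------------------------------------------------
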
def boundaries_to_indices(
    boundaries: list[tuple[int, int]],
    groups: tuple[int, ...],
) -> NDArray[np.intp]:
    """Convert group boundaries to flat index array for selected groups.

    Parameters
    ----------
    boundaries : list of tuple
        List of (start_idx, end_idx) for each group.
    groups : tuple of int
        Which groups to include.

    Returns
    -------
    indices : np.ndarray
        Indices for the selected groups, concatenated in the given group
        order (sorted whenever `groups` is ascending).
    """
    # Use numpy concatenation instead of Python list extend for performance
    ranges = [np.arange(boundaries[g][0], boundaries[g][1], dtype=np.intp) for g in groups]
    if not ranges:
        return np.array([], dtype=np.intp)
    return np.concatenate(ranges)


def exact_indices_to_array(
    group_indices_list: list[NDArray[np.intp]],
    groups: tuple[int, ...],
) -> NDArray[np.intp]:
    """Concatenate exact index arrays for selected groups.

    Parameters
    ----------
    group_indices_list : list of np.ndarray
        List of exact index arrays for each group.
    groups : tuple of int
        Which groups to include.

    Returns
    -------
    indices : np.ndarray
        Sorted array of indices for selected groups.
    """
    arrays = [group_indices_list[g] for g in groups]
    if not arrays or all(len(a) == 0 for a in arrays):
        return np.array([], dtype=np.intp)
    return np.sort(np.concatenate(arrays))
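
The two partitioning modes feed the CPCV splitter through the helpers above: contiguous boundaries go through boundaries_to_indices, while session-aligned groups go through exact_indices_to_array. A minimal sketch of the contiguous path (editor's illustration built on the functions in this file, not code shipped in the package):

from ml4t.diagnostic.splitters.cpcv.partitioning import (
    boundaries_to_indices,
    create_contiguous_partitions,
)

# Partition 103 samples into 5 contiguous groups.
boundaries = create_contiguous_partitions(103, 5)
# -> [(0, 21), (21, 42), (42, 63), (63, 83), (83, 103)]

# Hold out groups 1 and 3 for testing; train on the rest.
test_idx = boundaries_to_indices(boundaries, (1, 3))
train_idx = boundaries_to_indices(boundaries, (0, 2, 4))

# Every sample lands on exactly one side of the split.
assert len(test_idx) + len(train_idx) == 103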
ml4t/diagnostic/splitters/cpcv/purge_engine.py
@@ -0,0 +1,379 @@
"""Purging engine for CPCV.

This module implements the core purging and embargo logic:
- Mask-based purging (efficient for large datasets)
- Single-asset and multi-asset purging strategies
- Segment-based purging for temporal coherence
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
from numpy.typing import NDArray

from ml4t.diagnostic.core.purging import apply_purging_and_embargo
from ml4t.diagnostic.splitters.cpcv.windows import (
    find_contiguous_segments,
    timestamp_window_from_indices,
)
from ml4t.diagnostic.splitters.utils import convert_indices_to_timestamps

if TYPE_CHECKING:
    import pandas as pd
def apply_single_asset_purging(
    train_indices: NDArray[np.intp],
    test_group_indices: tuple[int, ...],
    group_boundaries: list[tuple[int, int]],
    n_samples: int,
    timestamps: pd.DatetimeIndex | None,
    label_horizon: int | pd.Timedelta,
    embargo_size: int | pd.Timedelta | None,
    embargo_pct: float | None,
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> NDArray[np.intp]:
    """Apply purging for single-asset data.

    For each test group, removes training samples that would cause
    look-ahead bias due to label overlap or temporal proximity.

    Parameters
    ----------
    train_indices : ndarray
        Initial training indices.
    test_group_indices : tuple of int
        Indices of groups used for testing.
    group_boundaries : list of tuple
        Boundaries (start, end) for each group.
    n_samples : int
        Total number of samples.
    timestamps : pd.DatetimeIndex, optional
        Timestamps for time-based purging.
    label_horizon : int or pd.Timedelta
        Forward-looking period of labels.
    embargo_size : int or pd.Timedelta, optional
        Buffer period after test set.
    embargo_pct : float, optional
        Embargo as percentage of samples.
    group_indices_list : list of ndarray, optional
        Exact indices per group (for session-aligned mode).

    Returns
    -------
    clean_indices : ndarray
        Training indices after purging.
    """
    for test_group_idx in test_group_indices:
        # Compute purge window bounds
        if group_indices_list is not None and timestamps is not None:
            # Session-aligned mode: use actual timestamps from test indices
            test_indices = group_indices_list[test_group_idx]
            window = timestamp_window_from_indices(test_indices, timestamps)
            if window is None:
                # Empty test group - skip purging for this group
                continue
            test_start_time = window.start
            test_end_time = window.end_exclusive
        else:
            # Standard mode: use boundaries
            test_start_idx, test_end_idx = group_boundaries[test_group_idx]
            test_start_time, test_end_time = convert_indices_to_timestamps(
                test_start_idx,
                test_end_idx,
                timestamps,
            )

        # Apply purging and embargo for this test group
        train_indices = apply_purging_and_embargo(
            train_indices=train_indices,
            test_start=test_start_time,
            test_end=test_end_time,
            label_horizon=label_horizon,
            embargo_size=embargo_size,
            embargo_pct=embargo_pct,
            n_samples=n_samples,
            timestamps=timestamps,
        )

    return train_indices
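
# --- Illustrative sketch (editor's note, not part of the package) ----------
# Position-based call: 10 samples in 5 contiguous groups, group 2 held out
# for testing, labels looking 1 bar ahead, no embargo. Which rows get dropped
# is decided by apply_purging_and_embargo in ml4t.diagnostic.core.purging:
#
#     boundaries = [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)]
#     train = np.array([0, 1, 2, 3, 6, 7, 8, 9], dtype=np.intp)
#     clean = apply_single_asset_purging(
#         train_indices=train,
#         test_group_indices=(2,),
#         group_boundaries=boundaries,
#         n_samples=10,
#         timestamps=None,
#         label_horizon=1,
#         embargo_size=None,
#         embargo_pct=None,
#     )
#     # Training bars whose 1-bar-ahead labels overlap the test window
#     # (here, index 3) should be purged from `clean`.
# ----------------------------------------------------------------------------
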
def apply_multi_asset_purging(
    train_indices: NDArray[np.intp],
    test_group_indices: tuple[int, ...],
    group_boundaries: list[tuple[int, int]],
    n_samples: int,
    timestamps: pd.DatetimeIndex | None,
    groups_array: NDArray[Any],
    label_horizon: int | pd.Timedelta,
    embargo_size: int | pd.Timedelta | None,
    embargo_pct: float | None,
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> NDArray[np.intp]:
    """Apply purging for multi-asset data with per-asset isolation.

    This function correctly handles non-contiguous test groups by applying
    purging for each contiguous segment of test data separately per asset.

    Parameters
    ----------
    train_indices : ndarray
        Initial training indices.
    test_group_indices : tuple of int
        Indices of groups used for testing.
    group_boundaries : list of tuple
        Boundaries (start, end) for each group.
    n_samples : int
        Total number of samples.
    timestamps : pd.DatetimeIndex, optional
        Timestamps for time-based purging.
    groups_array : ndarray
        Asset labels for each sample.
    label_horizon : int or pd.Timedelta
        Forward-looking period of labels.
    embargo_size : int or pd.Timedelta, optional
        Buffer period after test set.
    embargo_pct : float, optional
        Embargo as percentage of samples.
    group_indices_list : list of ndarray, optional
        Exact indices per group (for session-aligned mode).

    Returns
    -------
    clean_indices : ndarray
        Training indices after per-asset purging.
    """
    if len(groups_array) != n_samples:
        raise ValueError(
            f"groups length ({len(groups_array)}) must match number of samples ({n_samples})",
        )

    # Prepare test groups data for contiguous segment detection
    test_groups_data = prepare_test_groups_data(
        test_group_indices, group_boundaries, group_indices_list
    )

    # Apply purging per asset
    final_train_indices: list[int] = []
    unique_assets = np.unique(groups_array)

    for asset_id in unique_assets:
        # Process this asset's training data with purging
        asset_train = process_asset_purging(
            asset_id=asset_id,
            groups_array=groups_array,
            train_indices=train_indices,
            test_groups_data=test_groups_data,
            n_samples=n_samples,
            timestamps=timestamps,
            label_horizon=label_horizon,
            embargo_size=embargo_size,
            embargo_pct=embargo_pct,
            group_indices_list=group_indices_list,
        )
        final_train_indices.extend(asset_train)

    # Sort for deterministic output
    return np.sort(np.array(final_train_indices, dtype=np.intp))
def prepare_test_groups_data(
    test_group_indices: tuple[int, ...],
    group_boundaries: list[tuple[int, int]],
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> list[tuple[int, int, int, NDArray[np.intp] | None]]:
    """Prepare and sort test groups data for contiguous segment detection.

    Parameters
    ----------
    test_group_indices : tuple of int
        Which groups are used for testing.
    group_boundaries : list of tuple
        Boundaries (start, end) for each group.
    group_indices_list : list of ndarray, optional
        Exact indices per group (for session-aligned mode).

    Returns
    -------
    test_groups_data : list of tuple
        Sorted list of (group_idx, start_idx, end_idx, exact_indices).
        In session-aligned mode, exact_indices contains the actual row indices;
        otherwise it's None.
    """
    test_groups_data: list[tuple[int, int, int, NDArray[np.intp] | None]] = []
    for test_group_idx in test_group_indices:
        test_start_idx, test_end_idx = group_boundaries[test_group_idx]
        exact_indices = (
            group_indices_list[test_group_idx] if group_indices_list is not None else None
        )
        test_groups_data.append((test_group_idx, test_start_idx, test_end_idx, exact_indices))

    # Sort test groups by start index to identify contiguous segments
    test_groups_data.sort(key=lambda x: x[1])
    return test_groups_data
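
# --- Illustrative sketch (editor's note, not part of the package) ----------
# In standard (boundary-based) mode the exact_indices slot is None and the
# output is re-ordered by start index:
#
#     prepare_test_groups_data((3, 1), [(0, 5), (5, 10), (10, 15), (15, 20)])
#     # -> [(1, 5, 10, None), (3, 15, 20, None)]
# ----------------------------------------------------------------------------
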
def process_asset_purging(
    asset_id: Any,
    groups_array: NDArray[Any],
    train_indices: NDArray[np.intp],
    test_groups_data: list[tuple[int, int, int, NDArray[np.intp] | None]],
    n_samples: int,
    timestamps: pd.DatetimeIndex | None,
    label_horizon: int | pd.Timedelta,
    embargo_size: int | pd.Timedelta | None,
    embargo_pct: float | None,
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> list[int]:
    """Process purging for a single asset across all test segments.

    Parameters
    ----------
    asset_id : any
        Identifier for this asset.
    groups_array : ndarray
        Asset labels for all samples.
    train_indices : ndarray
        Candidate training indices.
    test_groups_data : list of tuple
        Test group information from prepare_test_groups_data.
    n_samples : int
        Total number of samples.
    timestamps : pd.DatetimeIndex, optional
        Timestamps for time-based purging.
    label_horizon : int or pd.Timedelta
        Forward-looking period of labels.
    embargo_size : int or pd.Timedelta, optional
        Buffer period after test set.
    embargo_pct : float, optional
        Embargo as percentage of samples.
    group_indices_list : list of ndarray, optional
        Exact indices per group (for session-aligned mode).

    Returns
    -------
    clean_indices : list of int
        Training indices for this asset after purging.
    """
    # Find indices for this asset
    asset_mask = groups_array == asset_id
    asset_indices = np.where(asset_mask)[0]

    # Get train indices for this asset
    asset_train_indices = np.intersect1d(train_indices, asset_indices)

    if len(asset_train_indices) == 0:
        return []

    # Find contiguous segments of test groups for this asset
    contiguous_segments = find_contiguous_segments(
        test_groups_data,
        asset_indices,
    )

    # If no test data for this asset, keep all training data
    if not contiguous_segments:
        return asset_train_indices.tolist()

    # Apply purging for each contiguous segment
    return apply_segment_purging(
        asset_train_indices=asset_train_indices,
        contiguous_segments=contiguous_segments,
        n_samples=n_samples,
        timestamps=timestamps,
        label_horizon=label_horizon,
        embargo_size=embargo_size,
        embargo_pct=embargo_pct,
        group_indices_list=group_indices_list,
    )
def apply_segment_purging(
    asset_train_indices: NDArray[np.intp],
    contiguous_segments: list[list[tuple[int, int, int, NDArray[np.intp] | None]]],
    n_samples: int,
    timestamps: pd.DatetimeIndex | None,
    label_horizon: int | pd.Timedelta,
    embargo_size: int | pd.Timedelta | None,
    embargo_pct: float | None,
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> list[int]:
    """Apply purging across all contiguous segments for an asset.

    Uses a set-based approach for tracking remaining indices, which is
    efficient for the iterative purging across segments.

    Parameters
    ----------
    asset_train_indices : ndarray
        Training indices for this asset.
    contiguous_segments : list of list of tuple
        Segments from find_contiguous_segments.
    n_samples : int
        Total number of samples.
    timestamps : pd.DatetimeIndex, optional
        Timestamps for time-based purging.
    label_horizon : int or pd.Timedelta
        Forward-looking period of labels.
    embargo_size : int or pd.Timedelta, optional
        Buffer period after test set.
    embargo_pct : float, optional
        Embargo as percentage of samples.
    group_indices_list : list of ndarray, optional
        Exact indices per group (for session-aligned mode).

    Returns
    -------
    clean_indices : list of int
        Sorted training indices after purging all segments.
    """
    remaining_train_indices = set(asset_train_indices)

    for segment in contiguous_segments:
        if not segment:
            continue

        # Compute purge window bounds
        if group_indices_list is not None and timestamps is not None:
            # Session-aligned mode: compute timestamp bounds from actual test indices
            segment_test_indices = np.concatenate([item[3] for item in segment])
            window = timestamp_window_from_indices(segment_test_indices, timestamps)
            if window is None:
                # Empty test segment - skip purging for this segment
                continue
            segment_start_time = window.start
            segment_end_time = window.end_exclusive
        else:
            # Standard mode: use boundaries
            segment_start_idx = segment[0][1]  # Start of first group in segment
            segment_end_idx = segment[-1][2]  # End of last group in segment
            segment_start_time, segment_end_time = convert_indices_to_timestamps(
                segment_start_idx,
                segment_end_idx,
                timestamps,
            )

        # Apply purging for this contiguous segment
        remaining_array = np.array(list(remaining_train_indices), dtype=np.intp)

        if len(remaining_array) == 0:
            break

        clean_segment_train = apply_purging_and_embargo(
            train_indices=remaining_array,
            test_start=segment_start_time,
            test_end=segment_end_time,
            label_horizon=label_horizon,
            embargo_size=embargo_size,
            embargo_pct=embargo_pct,
            n_samples=n_samples,
            timestamps=timestamps,
        )

        # Update remaining indices (remove those that were purged)
        remaining_train_indices = set(clean_segment_train)

    return sorted(remaining_train_indices)
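
Putting the two files together, a minimal end-to-end sketch of one purged CPCV fold (editor's illustration; integer positions, no timestamps, import paths as in the file list above):

from ml4t.diagnostic.splitters.cpcv.partitioning import (
    boundaries_to_indices,
    create_contiguous_partitions,
)
from ml4t.diagnostic.splitters.cpcv.purge_engine import apply_single_asset_purging

n_samples = 100
boundaries = create_contiguous_partitions(n_samples, 5)

# One CPCV fold: test on groups 1 and 3, train on the remaining groups.
train = boundaries_to_indices(boundaries, (0, 2, 4))

# Drop training bars whose 2-bar label horizon leaks into either test group;
# the embargo is disabled here for simplicity.
clean_train = apply_single_asset_purging(
    train_indices=train,
    test_group_indices=(1, 3),
    group_boundaries=boundaries,
    n_samples=n_samples,
    timestamps=None,
    label_horizon=2,
    embargo_size=None,
    embargo_pct=None,
)

# Purging only ever removes candidate training rows, never adds them.
assert set(clean_train) <= set(train)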