ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,190 @@
1
+ """Time window computation for CPCV purging.
2
+
3
+ This module handles computing purge windows from test indices:
4
+ - Timestamp windows from exact indices
5
+ - Contiguous segment detection
6
+ - Window merging for efficient purging
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from typing import TYPE_CHECKING
13
+
14
+ import numpy as np
15
+ from numpy.typing import NDArray
16
+
17
+ if TYPE_CHECKING:
18
+ import pandas as pd
19
+
20
+
21
@dataclass(frozen=True)
class TimeWindow:
    """A time window for purging, with exclusive end bound.

    Frozen (immutable) so instances can be safely shared and compared
    during purge-window bookkeeping.

    Attributes
    ----------
    start : pd.Timestamp
        Start of the window (inclusive).
    end_exclusive : pd.Timestamp
        End of the window (exclusive). An exclusive bound avoids
        ambiguity when duplicate timestamps sit exactly on the boundary.
    """

    # Inclusive lower bound of the window.
    start: pd.Timestamp
    # Exclusive upper bound of the window.
    end_exclusive: pd.Timestamp
35
+
36
+
37
def timestamp_window_from_indices(
    indices: NDArray[np.intp],
    timestamps: pd.DatetimeIndex,
) -> TimeWindow | None:
    """Build a purge window from the timestamps of the given row indices.

    Rather than relying on (min_row, max_row) positional boundaries —
    which can sweep in unrelated rows when assets are interleaved — the
    window is derived from the actual timestamps of the test rows. This
    is what makes purging correct in session-aligned mode.

    Parameters
    ----------
    indices : ndarray
        Row indices of the test samples.
    timestamps : pd.DatetimeIndex
        Timestamps for every sample in the dataset.

    Returns
    -------
    TimeWindow or None
        Window spanning [min timestamp, max timestamp + 1ns). None when
        ``indices`` is empty, signalling the caller to skip purging.

    Notes
    -----
    The exclusive end is obtained by adding one nanosecond to the max,
    which keeps duplicate boundary timestamps inside the window.

    Examples
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>> timestamps = pd.date_range("2020-01-01", periods=10, freq="D", tz="UTC")
    >>> indices = np.array([2, 3, 4])
    >>> window = timestamp_window_from_indices(indices, timestamps)
    >>> window.start
    Timestamp('2020-01-03 00:00:00+0000', tz='UTC')
    """
    # pandas is only imported under TYPE_CHECKING at module level, so a
    # runtime import is needed here.
    import pandas as pd

    if len(indices) == 0:
        # Empty test set: nothing to purge; caller should skip.
        return None

    selected = timestamps.take(indices)
    lower = selected.min()
    # +1ns turns the inclusive max into an exclusive upper bound.
    upper = selected.max() + pd.Timedelta(1, "ns")
    return TimeWindow(start=lower, end_exclusive=upper)
87
+
88
+
89
def find_contiguous_segments(
    test_groups_data: list[tuple[int, int, int, NDArray[np.intp] | None]],
    asset_indices: NDArray[np.intp],
) -> list[list[tuple[int, int, int, NDArray[np.intp]]]]:
    """Partition an asset's test groups into temporally contiguous runs.

    Groups are scanned in order and accumulated into a run as long as
    each group's start meets the previous group's (exclusive) end. A gap
    — or a group with no rows for this asset — closes the current run.
    Applying one purge window per run (instead of per group) is both
    cheaper and statistically correct.

    Parameters
    ----------
    test_groups_data : list of tuple
        Each tuple is (group_idx, group_start, group_end, exact_indices).
        ``exact_indices`` is non-None in session-aligned mode.
    asset_indices : ndarray
        Row indices that belong to this asset.

    Returns
    -------
    segments : list of list of tuple
        Each inner list is one contiguous run of
        (group_idx, start, end, asset_test_indices) entries.

    Notes
    -----
    In session-aligned mode the exact indices are used directly;
    generating indices via ``np.arange`` would be wrong for interleaved
    data.
    """
    segments: list[list[tuple[int, int, int, NDArray[np.intp]]]] = []
    pending: list[tuple[int, int, int, NDArray[np.intp]]] = []

    for grp_idx, grp_start, grp_end, exact in test_groups_data:
        # Candidate rows for this group: exact indices when available
        # (session-aligned mode), otherwise the boundary range.
        candidates = exact if exact is not None else np.arange(grp_start, grp_end)
        overlap = np.intersect1d(candidates, asset_indices)

        if overlap.size == 0:
            # Asset absent from this group: the run (if any) ends here.
            if pending:
                segments.append(pending)
                pending = []
            continue

        entry = (grp_idx, grp_start, grp_end, overlap)
        # Ends are exclusive, so contiguity means start == previous end;
        # anything larger is a gap that starts a new run.
        if pending and grp_start > pending[-1][2]:
            segments.append(pending)
            pending = [entry]
        else:
            pending.append(entry)

    # Flush the trailing run.
    if pending:
        segments.append(pending)

    return segments
153
+
154
+
155
def merge_windows(windows: list[TimeWindow]) -> list[TimeWindow]:
    """Collapse overlapping (or touching) time windows into one each.

    Reduces the number of purge operations when windows overlap, and
    gives clearer semantics about what is being purged.

    Parameters
    ----------
    windows : list of TimeWindow
        Windows to merge; order does not matter.

    Returns
    -------
    merged : list of TimeWindow
        Non-overlapping windows, sorted by start, covering the same
        time ranges as the input.
    """
    if not windows:
        return []

    result: list[TimeWindow] = []
    # Process in start order so each window only needs to be compared
    # against the most recently emitted one.
    for win in sorted(windows, key=lambda w: w.start):
        if result and win.start <= result[-1].end_exclusive:
            # Overlaps (or touches) the previous window: extend its end.
            prev = result[-1]
            result[-1] = TimeWindow(
                start=prev.start,
                end_exclusive=max(prev.end_exclusive, win.end_exclusive),
            )
        else:
            # Disjoint: emit as a new window.
            result.append(win)

    return result
@@ -0,0 +1,329 @@
1
+ """Group isolation utilities for multi-asset cross-validation.
2
+
3
+ This module provides utilities to prevent the same asset (e.g., contract, symbol)
4
+ from appearing in both training and test sets during cross-validation. This is
5
+ critical for avoiding data leakage in multi-asset strategies.
6
+
7
+ Example Use Cases
8
+ -----------------
9
+ 1. **Futures contracts**: Prevent ES_202312 from being in both train and test
10
+ 2. **Multiple symbols**: Ensure AAPL data doesn't leak between folds
11
+ 3. **Multi-strategy**: Isolate strategies to prevent cross-contamination
12
+
13
+ Integration with qdata
14
+ ----------------------
15
+ The `groups` parameter should contain asset identifiers that come from your
16
+ data pipeline. Typically this would be a column like 'symbol', 'contract',
17
+ or 'asset_id' from your DataFrame.
18
+
19
+ Example::
20
+
21
+ import polars as pl
22
+ from ml4t.diagnostic.splitters import PurgedWalkForwardCV
23
+
24
+ # Data with asset identifiers
25
+ df = pl.DataFrame({
26
+ 'timestamp': [...],
27
+ 'symbol': ['AAPL', 'AAPL', 'MSFT', 'MSFT', ...],
28
+ 'returns': [...]
29
+ })
30
+
31
+ # Cross-validate with group isolation
32
+ cv = PurgedWalkForwardCV(n_splits=5, isolate_groups=True)
33
+
34
+ for train_idx, test_idx in cv.split(df, groups=df['symbol']):
35
+ # Groups in test_idx will NEVER appear in train_idx
36
+ train_symbols = df[train_idx]['symbol'].unique()
37
+ test_symbols = df[test_idx]['symbol'].unique()
38
+ assert len(set(train_symbols) & set(test_symbols)) == 0
39
+ """
40
+
41
+ from __future__ import annotations
42
+
43
+ from typing import TYPE_CHECKING, Any
44
+
45
+ import numpy as np
46
+ import pandas as pd
47
+ import polars as pl
48
+
49
+ from ml4t.diagnostic.backends.adapter import DataFrameAdapter
50
+
51
+ if TYPE_CHECKING:
52
+ from numpy.typing import NDArray
53
+
54
+
55
+ def validate_group_isolation(
56
+ train_indices: NDArray[np.intp],
57
+ test_indices: NDArray[np.intp],
58
+ groups: pl.Series | pd.Series | NDArray[Any],
59
+ ) -> tuple[bool, set]:
60
+ """Validate that train and test sets have no overlapping groups.
61
+
62
+ Parameters
63
+ ----------
64
+ train_indices : ndarray
65
+ Training set indices.
66
+
67
+ test_indices : ndarray
68
+ Test set indices.
69
+
70
+ groups : array-like
71
+ Group labels for each sample.
72
+
73
+ Returns
74
+ -------
75
+ is_valid : bool
76
+ True if no groups overlap between train and test.
77
+
78
+ overlapping_groups : set
79
+ Set of group IDs that appear in both train and test.
80
+ Empty if is_valid=True.
81
+
82
+ Examples
83
+ --------
84
+ >>> import numpy as np
85
+ >>> train_idx = np.array([0, 1, 2, 3])
86
+ >>> test_idx = np.array([4, 5, 6, 7])
87
+ >>> groups = np.array(['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'])
88
+ >>> is_valid, overlap = validate_group_isolation(train_idx, test_idx, groups)
89
+ >>> assert is_valid # Groups don't overlap
90
+ >>> assert len(overlap) == 0
91
+ """
92
+ # Convert groups to numpy array
93
+ groups_array = DataFrameAdapter.to_numpy(groups).flatten()
94
+
95
+ # Get unique groups in train and test
96
+ train_groups = set(groups_array[train_indices])
97
+ test_groups = set(groups_array[test_indices])
98
+
99
+ # Find overlap
100
+ overlapping_groups = train_groups & test_groups
101
+
102
+ return len(overlapping_groups) == 0, overlapping_groups
103
+
104
+
105
def isolate_groups_from_train(
    train_indices: NDArray[np.intp],
    test_indices: NDArray[np.intp],
    groups: pl.Series | pd.Series | NDArray[Any],
) -> NDArray[np.intp]:
    """Remove samples from training set that share groups with test set.

    This function ensures strict group isolation by removing all training
    samples whose group appears anywhere in the test set.

    Parameters
    ----------
    train_indices : ndarray
        Initial training set indices.

    test_indices : ndarray
        Test set indices.

    groups : array-like
        Group labels for each sample.

    Returns
    -------
    clean_train_indices : ndarray
        Training indices with test groups removed. An empty
        ``train_indices`` input yields an empty result.

    Examples
    --------
    >>> import numpy as np
    >>> train_idx = np.array([0, 1, 2, 3, 4, 5])
    >>> test_idx = np.array([6, 7])
    >>> groups = np.array(['A', 'A', 'B', 'B', 'C', 'C', 'C', 'C'])
    >>> clean_train = isolate_groups_from_train(train_idx, test_idx, groups)
    >>> # Removes indices 4,5 because they share group 'C' with test indices 6,7
    >>> assert all(groups[clean_train] != 'C')

    Notes
    -----
    This can significantly reduce training set size if groups are imbalanced.
    Consider using group-aware splitting strategies to maintain balanced folds.
    """
    # Convert groups to numpy array
    groups_array = DataFrameAdapter.to_numpy(groups).flatten()

    # Unique groups present in the test set.
    test_groups = np.unique(groups_array[test_indices])

    # Vectorized membership test (consistent with split_by_groups). A
    # per-element Python comprehension here was O(n) interpreter work and,
    # for an empty train_indices, produced a float64 mask that broke the
    # boolean fancy-indexing below; np.isin handles both correctly.
    clean_train_mask = ~np.isin(groups_array[train_indices], test_groups)

    return train_indices[clean_train_mask]
156
+
157
+
158
def get_group_boundaries(
    groups: pl.Series | pd.Series | NDArray[Any],
    sorted_indices: NDArray[np.intp] | None = None,
) -> dict[Any, tuple[int, int]]:
    """Get start and end indices for each unique group in sorted data.

    This is useful for group-aware splitting where you want to keep groups
    contiguous and avoid splitting a group across train/test boundaries.

    Parameters
    ----------
    groups : array-like
        Group labels for each sample.

    sorted_indices : ndarray, optional
        Pre-sorted indices. If None, assumes data is already sorted by group.

    Returns
    -------
    boundaries : dict
        Mapping from group ID to (start_idx, end_idx) tuple, where
        ``start_idx`` is inclusive and ``end_idx`` is exclusive.

    Examples
    --------
    >>> import numpy as np
    >>> groups = np.array(['A', 'A', 'A', 'B', 'B', 'C'])
    >>> boundaries = get_group_boundaries(groups)
    >>> assert boundaries['A'] == (0, 3)
    >>> assert boundaries['B'] == (3, 5)
    >>> assert boundaries['C'] == (5, 6)

    Notes
    -----
    This assumes groups are contiguous in the data. If groups are interleaved,
    provide `sorted_indices` to ensure correct boundary detection.
    """
    # Convert groups to numpy array
    groups_array = DataFrameAdapter.to_numpy(groups).flatten()

    # Apply sorting if provided
    if sorted_indices is not None:
        groups_array = groups_array[sorted_indices]

    # Single pass: whenever the label changes, close the previous group's
    # boundary and open a new one. (The original kept an unused
    # ``unique_groups`` list alongside this; it has been removed.)
    # NOTE(review): None is used as the "no group yet" sentinel, so a
    # literal None group label would be mishandled — same as the original.
    boundaries: dict[Any, tuple[int, int]] = {}
    current_group = None
    start_idx = 0

    for i, group_id in enumerate(groups_array):
        if group_id != current_group:
            # Group changed - record previous group's boundary
            if current_group is not None:
                boundaries[current_group] = (start_idx, i)
            current_group = group_id
            start_idx = i

    # Close the final group, which runs to the end of the array.
    if current_group is not None:
        boundaries[current_group] = (start_idx, len(groups_array))

    return boundaries
223
+
224
+
225
def split_by_groups(
    n_samples: int,
    groups: pl.Series | pd.Series | NDArray[Any],
    test_group_indices: list[int],
    all_group_ids: list[Any],
) -> tuple[NDArray[np.intp], NDArray[np.intp]]:
    """Split samples into train/test based on group assignments.

    Every sample belonging to one of the selected test groups goes to the
    test set; all remaining samples form the training set.

    Parameters
    ----------
    n_samples : int
        Total number of samples.
        NOTE(review): not used by the implementation (the split is derived
        entirely from ``groups``); kept for interface compatibility.

    groups : array-like
        Group labels for each sample.

    test_group_indices : list of int
        Indices into `all_group_ids` specifying which groups go to test.

    all_group_ids : list
        Sorted list of all unique group IDs in the dataset.

    Returns
    -------
    train_indices : ndarray
        Indices of samples in training set.

    test_indices : ndarray
        Indices of samples in test set.

    Examples
    --------
    >>> import numpy as np
    >>> groups = np.array(['A', 'A', 'B', 'B', 'C', 'C'])
    >>> all_groups = ['A', 'B', 'C']
    >>> train_idx, test_idx = split_by_groups(
    ...     n_samples=6,
    ...     groups=groups,
    ...     test_group_indices=[2],  # Group 'C'
    ...     all_group_ids=all_groups
    ... )
    >>> assert set(groups[train_idx]) == {'A', 'B'}
    >>> assert set(groups[test_idx]) == {'C'}
    """
    # Normalize input to a flat numpy array of labels.
    labels = DataFrameAdapter.to_numpy(groups).flatten()

    # Resolve the selected positions to actual group IDs.
    selected_ids = [all_group_ids[i] for i in test_group_indices]

    # One vectorized membership test drives both sides of the split.
    is_test = np.isin(labels, selected_ids)
    test_indices = np.flatnonzero(is_test).astype(np.intp)
    train_indices = np.flatnonzero(~is_test).astype(np.intp)

    return train_indices, test_indices
287
+
288
+
289
def count_samples_per_group(
    groups: pl.Series | pd.Series | NDArray[Any],
) -> dict[Any, int]:
    """Count number of samples for each unique group.

    Useful for understanding group distribution and detecting imbalanced groups.

    Parameters
    ----------
    groups : array-like
        Group labels for each sample.

    Returns
    -------
    counts : dict
        Mapping from group ID to sample count.

    Examples
    --------
    >>> import numpy as np
    >>> groups = np.array(['A', 'A', 'A', 'B', 'B', 'C'])
    >>> counts = count_samples_per_group(groups)
    >>> assert counts == {'A': 3, 'B': 2, 'C': 1}
    """
    # Normalize input to a flat numpy array of labels.
    labels = DataFrameAdapter.to_numpy(groups).flatten()

    # np.unique does the counting in one C-level pass.
    unique_ids, freq = np.unique(labels, return_counts=True)
    return {gid: cnt for gid, cnt in zip(unique_ids, freq, strict=False)}
320
+
321
+
322
# Public API of this module: controls what `from ... import *` exposes
# and documents the supported entry points.
__all__ = [
    "validate_group_isolation",
    "isolate_groups_from_train",
    "get_group_boundaries",
    "split_by_groups",
    "count_samples_per_group",
]