ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1064 @@
1
+ """Combinatorial Purged Cross-Validation for backtest overfitting detection.
2
+
3
+ This module implements Combinatorial Purged Cross-Validation (CPCV), which generates
4
+ multiple backtest paths by combining different groups of time-series data. This approach
5
+ provides a distribution of performance metrics instead of a single path, enabling robust
6
+ assessment of strategy viability and detection of backtest overfitting.
7
+
8
+ Key Concepts
9
+ ------------
10
+
11
+ **Combinatorial Splits**:
12
+ Instead of a single chronological train/test split, CPCV partitions data into N groups
13
+ and generates all C(N,k) combinations of choosing k groups for testing. This creates
14
+ a distribution of backtest results rather than a single path.
15
+
16
+ **Purging**:
17
+ Removes training samples that temporally overlap with test samples within the label
18
+ horizon. Essential for preventing information leakage when labels are forward-looking
19
+ (e.g., future returns). Without purging, the model could train on samples that contain
20
+ information about test set labels.
21
+
22
+ **Embargo**:
23
+ Creates a buffer period after each test group where training samples are removed.
24
+ Accounts for serial correlation in financial data and prevents training on samples
25
+ that are too close in time to the test set. Can be specified as absolute time
26
+ (embargo_size) or as a percentage of total samples (embargo_pct).
27
+
28
+ **Session Alignment**:
29
+ Optionally aligns group boundaries to trading session boundaries rather than arbitrary
30
+ indices. Ensures groups represent complete trading days/sessions, which is important
31
+ for intraday strategies.
32
+
33
+ **Multi-Asset Isolation**:
34
+ When groups parameter is provided, CPCV applies purging per asset independently.
35
+ This prevents cross-asset information leakage and enables proper validation of
36
+ multi-asset strategies.
37
+
38
+ Usage Example
39
+ -------------
40
+ Basic usage with purging and embargo::
41
+
42
+ import polars as pl
43
+ from ml4t.diagnostic.splitters import CombinatorialPurgedCV
44
+
45
+ # Load your time-series data
46
+ df = pl.read_parquet("features.parquet")
47
+ X = df.select(["feature1", "feature2", "feature3"])
48
+ y = df["target"]
49
+
50
+ # Configure CPCV with purging for 5-day forward labels
51
+ # and 2-day embargo to account for autocorrelation
52
+ cv = CombinatorialPurgedCV(
53
+ n_groups=8, # Split into 8 time groups
54
+ n_test_groups=2, # Use 2 groups for testing in each combination
55
+ label_horizon=5, # Labels look forward 5 samples
56
+ embargo_size=2, # Add 2-sample buffer after test set
57
+ max_combinations=20 # Limit to 20 combinations for efficiency
58
+ )
59
+
60
+ # Generate train/test splits
61
+ for fold, (train_idx, test_idx) in enumerate(cv.split(X)):
62
+ X_train, X_test = X[train_idx], X[test_idx]
63
+ y_train, y_test = y[train_idx], y[test_idx]
64
+
65
+ # Train and evaluate your model
66
+ model.fit(X_train, y_train)
67
+ score = model.score(X_test, y_test)
68
+ print(f"Fold {fold}: Score={score:.4f}")
69
+
70
+ Multi-asset usage with per-asset purging::
71
+
72
+ # For multi-asset strategies, provide asset IDs as groups
73
+ assets = df["symbol"] # e.g., ["AAPL", "MSFT", "GOOGL", ...]
74
+
75
+ cv = CombinatorialPurgedCV(
76
+ n_groups=6,
77
+ n_test_groups=2,
78
+ label_horizon=5,
79
+ embargo_size=2,
80
+ isolate_groups=True # Prevent same asset in train and test
81
+ )
82
+
83
+ for train_idx, test_idx in cv.split(X, groups=assets):
84
+ # CPCV automatically applies per-asset purging
85
+ # Each asset's data is purged independently
86
+ pass
87
+
88
+ Session-aligned usage for intraday strategies::
89
+
90
+ import pandas as pd
91
+
92
+ # Data with session_date column from qdata.sessions
93
+ df = pd.read_parquet("intraday_features.parquet")
94
+ # df has columns: timestamp, session_date, feature1, feature2, ...
95
+
96
+ cv = CombinatorialPurgedCV(
97
+ n_groups=10,
98
+ n_test_groups=2,
99
+ label_horizon=pd.Timedelta(minutes=30), # 30-minute forward labels
100
+ embargo_size=pd.Timedelta(minutes=15), # 15-minute embargo
101
+ align_to_sessions=True, # Align groups to sessions
102
+ session_col="session_date" # Column with session IDs
103
+ )
104
+
105
+ for train_idx, test_idx in cv.split(df):
106
+ # Group boundaries now align to complete trading sessions
107
+ pass
108
+
109
+ References
110
+ ----------
111
+ .. [1] Bailey, D. H., Borwein, J., López de Prado, M., & Zhu, Q. J. (2014).
112
+ "The Probability of Backtest Overfitting." Journal of Computational Finance.
113
+
114
+ .. [2] López de Prado, M. (2018). "Advances in Financial Machine Learning."
115
+ Wiley. Chapter 7: Cross-Validation in Finance.
116
+ """
117
+
118
+ from __future__ import annotations
119
+
120
+ import math
121
+ from collections.abc import Generator
122
+ from typing import TYPE_CHECKING, Any, cast
123
+
124
+ import numpy as np
125
+ import pandas as pd
126
+ import polars as pl
127
+
128
+ from ml4t.diagnostic.backends.adapter import DataFrameAdapter
129
+ from ml4t.diagnostic.splitters.base import BaseSplitter
130
+ from ml4t.diagnostic.splitters.config import CombinatorialPurgedConfig
131
+ from ml4t.diagnostic.splitters.cpcv import (
132
+ apply_multi_asset_purging,
133
+ apply_single_asset_purging,
134
+ create_contiguous_partitions,
135
+ create_session_partitions,
136
+ iter_combinations,
137
+ timestamp_window_from_indices,
138
+ validate_contiguous_partitions,
139
+ )
140
+ from ml4t.diagnostic.splitters.group_isolation import isolate_groups_from_train
141
+
142
+ if TYPE_CHECKING:
143
+ from numpy.typing import NDArray
144
+
145
+
146
+ class CombinatorialPurgedCV(BaseSplitter):
147
+ """Combinatorial Purged Cross-Validation for backtest overfitting detection.
148
+
149
+ CPCV partitions the time series into N contiguous groups and forms all combinations
150
+ C(N,k) of choosing k groups for testing. This generates multiple backtest paths
151
+ instead of a single chronological split, providing a robust assessment of strategy
152
+ performance and enabling detection of backtest overfitting.
153
+
154
+ How It Works
155
+ ------------
156
+
157
+ 1. **Partitioning**: Divide time-series data into N contiguous groups of equal size
158
+ 2. **Combination Generation**: Generate all C(N,k) combinations of choosing k groups for testing
159
+ 3. **Purging**: For each combination, remove training samples that overlap with test labels
160
+ 4. **Embargo**: Optionally add buffer periods after test groups to account for autocorrelation
161
+ 5. **Multi-Asset Handling**: When groups are provided, apply purging independently per asset
162
+
163
+ Purging Mechanics
164
+ -----------------
165
+
166
+ **Why Purge?**
167
+ When labels are forward-looking (e.g., 5-day returns), training samples near the test
168
+ set temporally overlap with test labels. Without purging, the model trains on information
169
+ about test outcomes, leading to inflated performance estimates.
170
+
171
+ **How Purging Works**:
172
+ For each test group with range [t_start, t_end]:
173
+
174
+ 1. Remove train samples where: ``t_start - label_horizon <= t_train < t_start``
175
+ 2. This ensures no training sample's label period overlaps with test samples
176
+
177
+ **Example**::
178
+
179
+ Test group: samples 100-119 (20 samples)
180
+ Label horizon: 5 samples
181
+ Purging removes: training samples 95-99
182
+ Reason: Sample 95's label looks forward to sample 100 (first test sample)
183
+
184
+ Embargo Mechanics
185
+ -----------------
186
+
187
+ **Why Embargo?**
188
+ Financial data exhibits serial correlation - adjacent samples are not independent.
189
+ Purging only removes samples before each test group; training on samples immediately
190
+ after the test set can still leak information through autocorrelation.
191
+
192
+ **How Embargo Works**:
193
+ After purging, additionally remove a buffer of samples immediately after each test group:
194
+
195
+ - **embargo_size**: Absolute number of samples (e.g., 10 samples)
196
+ - **embargo_pct**: Percentage of total samples (e.g., 0.01 = 1%)
197
+
198
+ **Example**::
199
+
200
+ Test group: samples 100-119
201
+ Embargo: 5 samples
202
+ Additional removal: training samples 120-124
203
+ Result: Creates 5-sample buffer after test group
204
+
205
+ Multi-Asset Purging
206
+ -------------------
207
+
208
+ When ``groups`` parameter is provided (e.g., asset symbols), CPCV applies purging
209
+ independently for each asset. This prevents cross-asset leakage:
210
+
211
+ **Process**:
212
+ 1. For each asset, find its training and test samples
213
+ 2. Apply purging/embargo only to that asset's data
214
+ 3. Combine results across all assets
215
+
216
+ **Why Important?**
217
+ Without per-asset purging, information could leak between assets that trade at
218
+ different times (e.g., European markets vs US markets).
219
+
220
+ Based on Bailey et al. (2014) "The Probability of Backtest Overfitting" and
221
+ López de Prado (2018) "Advances in Financial Machine Learning".
222
+
223
+ Parameters
224
+ ----------
225
+ n_groups : int, default=8
226
+ Number of contiguous groups to partition the time series into.
227
+
228
+ n_test_groups : int, default=2
229
+ Number of groups to use for testing in each combination.
230
+
231
+ label_horizon : int or pd.Timedelta, default=0
232
+ Forward-looking period of labels for purging calculation.
233
+
234
+ embargo_size : int or pd.Timedelta, optional
235
+ Size of embargo period after each test group.
236
+
237
+ embargo_pct : float, optional
238
+ Embargo size as percentage of total samples.
239
+
240
+ max_combinations : int, optional
241
+ Maximum number of combinations to generate. If None, generates all C(N,k).
242
+ Use this to limit computational cost for large N.
243
+
244
+ random_state : int, optional
245
+ Random seed for combination sampling when max_combinations is set.
246
+
247
+ align_to_sessions : bool, default=False
248
+ If True, align group boundaries to trading session boundaries.
249
+ Requires X to have a session column (specified by session_col parameter).
250
+
251
+ Trading sessions should be assigned using the qdata library before cross-validation:
252
+ - Use DataManager with exchange/calendar parameters, or
253
+ - Use SessionAssigner.from_exchange('CME') directly
254
+
255
+ session_col : str, default='session_date'
256
+ Name of the column containing session identifiers.
257
+ Only used if align_to_sessions=True.
258
+ This column should be added by qdata.sessions.SessionAssigner
259
+
260
+ isolate_groups : bool, default=True
261
+ If True, prevent the same group (asset/symbol) from appearing in both
262
+ train and test sets. This is enabled by default for CPCV as it's designed
263
+ for multi-asset validation.
264
+
265
+ Requires passing `groups` parameter to split() method with asset IDs.
266
+
267
+ Note: CPCV already applies per-asset purging when groups are provided.
268
+ This parameter provides additional group isolation guarantee.
269
+
270
+ Attributes
271
+ ----------
272
+ n_groups_ : int
273
+ The number of groups.
274
+
275
+ n_test_groups_ : int
276
+ The number of test groups.
277
+
278
+ Examples
279
+ --------
280
+ >>> import numpy as np
281
+ >>> from ml4t.diagnostic.splitters import CombinatorialPurgedCV
282
+ >>> X = np.arange(200).reshape(200, 1)
283
+ >>> cv = CombinatorialPurgedCV(n_groups=6, n_test_groups=2, label_horizon=5)
284
+ >>> combinations = list(cv.split(X))
285
+ >>> print(f"Generated {len(combinations)} combinations")
286
+ Generated 15 combinations
287
+
288
+ >>> # Each combination provides disjoint train/test index arrays
289
+ >>> train, test = combinations[0]
290
+ >>> len(np.intersect1d(train, test))
291
+ 0
292
+ >>> len(train) + len(test) <= len(X)
293
+ True
294
+
295
+ Notes
296
+ -----
297
+ The total number of combinations is C(n_groups, n_test_groups). For large values,
298
+ this can become computationally expensive:
299
+ - C(8,2) = 28 combinations
300
+ - C(10,3) = 120 combinations
301
+ - C(12,4) = 495 combinations
302
+
303
+ Use max_combinations to limit computational cost for large datasets.
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ config: CombinatorialPurgedConfig | None = None,
309
+ *,
310
+ n_groups: int = 8,
311
+ n_test_groups: int = 2,
312
+ label_horizon: int | pd.Timedelta = 0,
313
+ embargo_size: int | pd.Timedelta | None = None,
314
+ embargo_pct: float | None = None,
315
+ max_combinations: int | None = None,
316
+ random_state: int | None = None,
317
+ align_to_sessions: bool = False,
318
+ session_col: str = "session_date",
319
+ timestamp_col: str | None = None,
320
+ isolate_groups: bool = True,
321
+ ) -> None:
322
+ """Initialize CombinatorialPurgedCV.
323
+
324
+ This splitter uses a config-first architecture. You can either:
325
+ 1. Pass a config object: CombinatorialPurgedCV(config=my_config)
326
+ 2. Pass individual parameters: CombinatorialPurgedCV(n_groups=8, n_test_groups=2)
327
+
328
+ Parameters are automatically converted to a config object internally,
329
+ ensuring a single source of truth for all validation and logic.
330
+
331
+ Examples
332
+ --------
333
+ >>> # Approach 1: Direct parameters (convenient)
334
+ >>> cv = CombinatorialPurgedCV(n_groups=10, n_test_groups=3)
335
+ >>>
336
+ >>> # Approach 2: Config object (for serialization/reproducibility)
337
+ >>> from ml4t.diagnostic.splitters.config import CombinatorialPurgedConfig
338
+ >>> config = CombinatorialPurgedConfig(n_groups=10, n_test_groups=3)
339
+ >>> cv = CombinatorialPurgedCV(config=config)
340
+ >>>
341
+ >>> # Config can be serialized
342
+ >>> config.to_json("cpcv_config.json")
343
+ >>> loaded = CombinatorialPurgedConfig.from_json("cpcv_config.json")
344
+ >>> cv = CombinatorialPurgedCV(config=loaded)
345
+ """
346
+ # Config-first: either use provided config or create from params
347
+ if config is not None:
348
+ # Verify no conflicting parameters when config is provided
349
+ self._validate_no_param_conflicts(
350
+ n_groups,
351
+ n_test_groups,
352
+ label_horizon,
353
+ embargo_size,
354
+ embargo_pct,
355
+ max_combinations,
356
+ random_state,
357
+ align_to_sessions,
358
+ session_col,
359
+ timestamp_col,
360
+ isolate_groups,
361
+ )
362
+ self.config = config
363
+ else:
364
+ # Create config from individual parameters
365
+ # Note: embargo validation (mutual exclusivity) handled by config
366
+ self.config = self._create_config_from_params(
367
+ n_groups,
368
+ n_test_groups,
369
+ label_horizon,
370
+ embargo_size,
371
+ embargo_pct,
372
+ max_combinations,
373
+ random_state,
374
+ align_to_sessions,
375
+ session_col,
376
+ timestamp_col,
377
+ isolate_groups,
378
+ )
379
+
380
+ # Use parameter if provided, otherwise use config value
381
+ # This allows random_state to be passed either via config or direct parameter
382
+ self.random_state = random_state if random_state is not None else self.config.random_state
383
+
384
+ def _validate_no_param_conflicts(
385
+ self,
386
+ n_groups: int,
387
+ n_test_groups: int,
388
+ label_horizon: int | pd.Timedelta,
389
+ embargo_size: int | pd.Timedelta | None,
390
+ embargo_pct: float | None,
391
+ max_combinations: int | None,
392
+ random_state: int | None,
393
+ align_to_sessions: bool,
394
+ session_col: str,
395
+ timestamp_col: str | None,
396
+ isolate_groups: bool,
397
+ ) -> None:
398
+ """Validate no conflicting parameters when config is provided."""
399
+
400
+ def is_semantically_default(value: Any, default: Any) -> bool:
401
+ """Check if value is semantically equal to default.
402
+
403
+ Handles heterogeneous types:
404
+ - pd.Timedelta(0) is semantically equal to 0
405
+ - np.int64(0) is semantically equal to 0
406
+ - None equals None
407
+ """
408
+ if value is None and default is None:
409
+ return True
410
+ if value is None or default is None:
411
+ return False
412
+ # Handle Timedelta vs int comparison for label_horizon/embargo_size
413
+ if isinstance(value, pd.Timedelta):
414
+ if isinstance(default, int) and default == 0:
415
+ return value == pd.Timedelta(0)
416
+ return value == default
417
+ if isinstance(default, pd.Timedelta):
418
+ if isinstance(value, int) and value == 0:
419
+ return default == pd.Timedelta(0)
420
+ return value == default
421
+ # Handle numpy int types vs Python int
422
+ try:
423
+ return bool(value == default)
424
+ except (TypeError, ValueError):
425
+ return False
426
+
427
+ # Check for non-default parameter values
428
+ # Note: random_state is NOT in this list because it's now in config.
429
+ # Users can pass random_state as a parameter to override config.random_state.
430
+ param_checks = [
431
+ ("n_groups", n_groups, 8),
432
+ ("n_test_groups", n_test_groups, 2),
433
+ ("label_horizon", label_horizon, 0),
434
+ ("embargo_size", embargo_size, None),
435
+ ("embargo_pct", embargo_pct, None),
436
+ ("max_combinations", max_combinations, None),
437
+ ("align_to_sessions", align_to_sessions, False),
438
+ ("session_col", session_col, "session_date"),
439
+ ("timestamp_col", timestamp_col, None),
440
+ ("isolate_groups", isolate_groups, True),
441
+ ]
442
+
443
+ non_default_params = [
444
+ name
445
+ for name, value, default in param_checks
446
+ if not is_semantically_default(value, default)
447
+ ]
448
+
449
+ if non_default_params:
450
+ raise ValueError(
451
+ f"Cannot specify both 'config' and individual parameters. "
452
+ f"Got config plus: {', '.join(non_default_params)}"
453
+ )
454
+
455
def _create_config_from_params(
    self,
    n_groups: int,
    n_test_groups: int,
    label_horizon: int | pd.Timedelta,
    embargo_size: int | pd.Timedelta | None,
    embargo_pct: float | None,
    max_combinations: int | None,
    random_state: int | None,
    align_to_sessions: bool,
    session_col: str,
    timestamp_col: str | None,
    isolate_groups: bool,
) -> CombinatorialPurgedConfig:
    """Assemble a ``CombinatorialPurgedConfig`` from individual constructor arguments.

    Every argument is forwarded to the config field of the same name, with
    one exception: ``embargo_size`` is stored under the field ``embargo_td``.
    """
    field_values = {
        "n_groups": n_groups,
        "n_test_groups": n_test_groups,
        "label_horizon": label_horizon,
        "embargo_td": embargo_size,  # public name differs from config field name
        "embargo_pct": embargo_pct,
        "max_combinations": max_combinations,
        "random_state": random_state,
        "align_to_sessions": align_to_sessions,
        "session_col": session_col,
        "timestamp_col": timestamp_col,
        "isolate_groups": isolate_groups,
    }
    return CombinatorialPurgedConfig(**field_values)
483
+
484
# Read-only accessors that forward to the underlying config object.
@property
def n_groups(self) -> int:
    """Number of groups the sample timeline is partitioned into."""
    config = self.config
    return config.n_groups
489
+
490
@property
def n_test_groups(self) -> int:
    """How many groups are combined to form each test set."""
    config = self.config
    return config.n_test_groups
494
+
495
+ @property
496
+ def label_horizon(self) -> int | pd.Timedelta:
497
+ """Forward-looking period of labels (int samples or Timedelta)."""
498
+ return self.config.label_horizon
499
+
500
+ @property
501
+ def embargo_size(self) -> int | pd.Timedelta | None:
502
+ """Embargo buffer size (int samples or Timedelta)."""
503
+ return self.config.embargo_td
504
+
505
+ @property
506
+ def embargo_pct(self) -> float | None:
507
+ """Embargo size as percentage of total samples."""
508
+ return self.config.embargo_pct
509
+
510
+ @property
511
+ def max_combinations(self) -> int | None:
512
+ """Maximum number of folds to generate."""
513
+ return self.config.max_combinations
514
+
515
@property
def align_to_sessions(self) -> bool:
    """True when group boundaries should follow session boundaries."""
    config = self.config
    return config.align_to_sessions
519
+
520
@property
def session_col(self) -> str:
    """Name of the column holding session identifiers."""
    config = self.config
    return config.session_col
524
+
525
+ @property
526
+ def timestamp_col(self) -> str | None:
527
+ """Column name containing timestamps for time-based operations."""
528
+ return self.config.timestamp_col
529
+
530
@property
def isolate_groups(self) -> bool:
    """True when samples sharing a group label with the test set are kept out of training."""
    config = self.config
    return config.isolate_groups
534
+
535
def get_n_splits(
    self,
    X: pl.DataFrame | pd.DataFrame | NDArray[Any] | None = None,
    y: pl.Series | pd.Series | NDArray[Any] | None = None,
    groups: pl.Series | pd.Series | NDArray[Any] | None = None,
) -> int:
    """Return how many train/test splits ``split`` will yield.

    The count is ``C(n_groups, n_test_groups)``. When ``max_combinations`` is
    set, the count is capped at that value.

    Parameters
    ----------
    X, y, groups : array-like, optional
        Ignored; accepted only for scikit-learn API compatibility.

    Returns
    -------
    n_splits : int
        Number of combinations that will be generated.
    """
    del X, y, groups  # scikit-learn signature compatibility only

    n_possible = math.comb(self.n_groups, self.n_test_groups)
    cap = self.max_combinations
    return n_possible if cap is None else min(cap, n_possible)
565
+
566
def split(
    self,
    X: pl.DataFrame | pd.DataFrame | NDArray[Any],
    y: pl.Series | pd.Series | NDArray[Any] | None = None,
    groups: pl.Series | pd.Series | NDArray[Any] | None = None,
) -> Generator[tuple[NDArray[np.intp], NDArray[np.intp]], None, None]:
    """Generate train/test indices for combinatorial splits with purging and embargo.

    This method generates all combinations C(N,k) of train/test splits, applying
    purging and embargo to prevent information leakage. Each yielded split represents
    an independent backtest path.

    Parameters
    ----------
    X : DataFrame or ndarray of shape (n_samples, n_features)
        Training data. Must have a datetime index if using Timedelta-based
        label_horizon or embargo_size.

    y : Series or ndarray of shape (n_samples,), optional
        Target variable. Not used in splitting logic, but accepted for
        API compatibility with scikit-learn.

    groups : Series or ndarray of shape (n_samples,), optional
        Group labels for samples (e.g., asset symbols for multi-asset strategies).

        When provided:
        - Purging is applied independently per group (asset)
        - Prevents information leakage across groups
        - Essential for multi-asset portfolio validation

        Example: ``groups = df["symbol"]`` # ["AAPL", "MSFT", "GOOGL", ...]

    Yields
    ------
    train : ndarray of shape (n_train_samples,)
        Sorted indices of training samples for this combination.
        Purging and embargo have been applied to remove:
        - Samples overlapping with test labels (purging)
        - Samples in embargo buffer after test groups (embargo)

    test : ndarray of shape (n_test_samples,)
        Sorted indices of test samples for this combination.
        Consists of samples from the k selected test groups.

    Raises
    ------
    ValueError
        If X has incompatible shape or missing required columns
        (e.g., session_col when align_to_sessions=True), if purging/embargo
        leaves a fold with an empty training set, or if train and test
        indices overlap.

    TypeError
        If X index is not datetime when using Timedelta parameters.

    Notes
    -----
    **Number of Combinations**:
    Generates C(n_groups, n_test_groups) combinations. For example:
    - C(8,2) = 28 combinations
    - C(10,3) = 120 combinations
    - C(12,4) = 495 combinations

    Use ``max_combinations`` parameter to limit the number of splits generated.

    **Purging Logic**:
    For each test group:
    1. Identify test sample range [t_start, t_end]
    2. Remove training samples that precede ``t_start`` by less than
       ``label_horizon``, because their forward-looking labels reach into the
       test range. Exact boundary handling lives in the purge engine
       (``apply_single_asset_purging`` / ``apply_multi_asset_purging``).
    3. This prevents training on samples whose labels overlap with test period

    **Embargo Logic**:
    After purging, additionally remove training samples:
    - In range [t_end + 1, t_end + embargo_size]
    - This accounts for serial correlation in financial time series

    **Multi-Asset Handling**:
    When ``groups`` is provided:
    1. For each asset, find its training and test indices
    2. Apply purging/embargo independently to that asset's data
    3. Combine purged results across all assets
    4. This prevents cross-asset information leakage

    **Session Alignment**:
    When ``align_to_sessions=True``:
    - Group boundaries align to trading session boundaries
    - Ensures each group contains complete trading days/sessions
    - Requires X to have column specified by ``session_col`` parameter

    Examples
    --------
    Basic usage with purging::

        >>> import polars as pl
        >>> from ml4t.diagnostic.splitters import CombinatorialPurgedCV
        >>>
        >>> # Create sample data
        >>> n = 1000
        >>> X = pl.DataFrame({"feature1": range(n), "feature2": range(n, 2*n)})
        >>> y = pl.Series(range(n))
        >>>
        >>> # Configure CPCV
        >>> cv = CombinatorialPurgedCV(
        ...     n_groups=8,
        ...     n_test_groups=2,
        ...     label_horizon=5,
        ...     embargo_size=2
        ... )
        >>>
        >>> # Generate splits
        >>> for fold, (train_idx, test_idx) in enumerate(cv.split(X)):
        ...     print(f"Fold {fold}: Train={len(train_idx)}, Test={len(test_idx)}")
        Fold 0: Train=739, Test=250
        Fold 1: Train=739, Test=250
        ...

    Multi-asset usage::

        >>> # Multi-asset data with symbol column
        >>> symbols = pl.Series(["AAPL"] * 250 + ["MSFT"] * 250 +
        ...                     ["GOOGL"] * 250 + ["AMZN"] * 250)
        >>>
        >>> cv = CombinatorialPurgedCV(
        ...     n_groups=6,
        ...     n_test_groups=2,
        ...     label_horizon=5,
        ...     embargo_size=2,
        ...     isolate_groups=True
        ... )
        >>>
        >>> for train_idx, test_idx in cv.split(X, groups=symbols):
        ...     # Purging applied independently per asset
        ...     train_symbols = symbols[train_idx].unique()
        ...     test_symbols = symbols[test_idx].unique()

    Session-aligned usage::

        >>> import pandas as pd
        >>>
        >>> # Intraday data with session dates
        >>> df = pd.DataFrame({
        ...     "timestamp": pd.date_range("2024-01-01", periods=1000, freq="1min"),
        ...     "session_date": pd.date_range("2024-01-01", periods=1000, freq="1min").date,
        ...     "feature1": range(1000)
        ... })
        >>>
        >>> cv = CombinatorialPurgedCV(
        ...     n_groups=10,
        ...     n_test_groups=2,
        ...     label_horizon=pd.Timedelta(minutes=30),
        ...     embargo_size=pd.Timedelta(minutes=15),
        ...     align_to_sessions=True,
        ...     session_col="session_date"
        ... )
        >>>
        >>> for train_idx, test_idx in cv.split(df):
        ...     # Group boundaries aligned to session boundaries
        ...     pass

    See Also
    --------
    CombinatorialPurgedConfig : Configuration object for CPCV parameters
    apply_purging_and_embargo : Low-level purging/embargo function
    BaseSplitter : Base class for all splitters
    """
    # Validate inputs (no numpy conversion - performance optimization)
    n_samples = self._validate_inputs(X, y, groups)

    # Validate session alignment if enabled
    self._validate_session_alignment(X, self.align_to_sessions, self.session_col)

    # Extract timestamps if available (supports both Polars and pandas)
    timestamps = self._extract_timestamps(X, self.timestamp_col)

    # Session-aligned mode needs exact per-group index arrays: groups may be
    # non-contiguous/interleaved, so (start, end) ranges would be wrong.
    # Boundaries are still derived for the purge engine's boundary-based paths.
    group_indices_list: list[NDArray[np.intp]] | None
    if self.align_to_sessions:
        # align_to_sessions requires X to be a DataFrame (validation enforces this)
        group_indices_list = self._create_session_group_indices(
            cast(pl.DataFrame | pd.DataFrame, X)
        )
        group_boundaries = [
            (int(indices[0]), int(indices[-1]) + 1) if len(indices) > 0 else (0, 0)
            for indices in group_indices_list
        ]
    else:
        group_boundaries = self._create_group_boundaries(n_samples)
        group_indices_list = None

    # Uses reservoir sampling when max_combinations is set, so all C(n,k)
    # combinations are never materialized at once.
    combinations = iter_combinations(
        self.n_groups,
        self.n_test_groups,
        self.max_combinations,
        self.random_state,
    )

    for test_group_indices in combinations:
        test_group_set = frozenset(test_group_indices)
        train_group_ids = [g for g in range(self.n_groups) if g not in test_group_set]

        test_indices_array = self._collect_group_sample_indices(
            test_group_indices, group_boundaries, group_indices_list
        )
        train_indices_array = self._collect_group_sample_indices(
            train_group_ids, group_boundaries, group_indices_list
        )

        # Apply purging and embargo between test groups and training data
        clean_train_indices = self._apply_group_purging_and_embargo(
            train_indices_array,
            test_group_indices,
            group_boundaries,
            n_samples,
            timestamps,
            groups,  # Pass groups for multi-asset awareness
            group_indices_list,  # Pass exact indices for session-aligned purging
        )

        # Apply group isolation if requested
        if self.isolate_groups and groups is not None:
            clean_train_indices = isolate_groups_from_train(
                clean_train_indices, test_indices_array, groups
            )

        # CPCV Invariant: train set must not be empty after purging
        if len(clean_train_indices) == 0:
            raise ValueError(
                f"CPCV invariant violated: train set is empty after purging/embargo. "
                f"Test groups: {test_group_indices}. "
                f"Consider reducing label_horizon ({self.label_horizon}) or "
                f"embargo_size ({self.embargo_size}) or embargo_pct ({self.embargo_pct})."
            )

        # CPCV Invariant: train and test sets must be disjoint
        overlap = np.intersect1d(clean_train_indices, test_indices_array)
        if len(overlap) > 0:
            raise ValueError(
                f"CPCV invariant violated: train and test sets have {len(overlap)} "
                f"overlapping indices. First few: {overlap[:5].tolist()}"
            )

        # Return sorted indices for deterministic behavior
        yield np.sort(clean_train_indices), np.sort(test_indices_array)


@staticmethod
def _collect_group_sample_indices(
    group_ids: tuple[int, ...] | list[int],
    group_boundaries: list[tuple[int, int]],
    group_indices_list: list[NDArray[np.intp]] | None,
) -> NDArray[np.intp]:
    """Concatenate the sample indices of ``group_ids``, in the given group order.

    With ``group_indices_list`` (session-aligned mode), the exact per-group
    arrays are used. Otherwise each group's half-open ``(start, end)``
    boundary is expanded with ``np.arange``, which is only correct for
    contiguous data. This avoids building large Python integer lists per fold.
    """
    if group_indices_list is not None:
        parts = [group_indices_list[g] for g in group_ids]
    else:
        parts = [
            np.arange(group_boundaries[g][0], group_boundaries[g][1], dtype=np.intp)
            for g in group_ids
        ]
    if not parts:
        return np.array([], dtype=np.intp)
    return np.concatenate(parts)
838
+
839
def _create_group_boundaries(self, n_samples: int) -> list[tuple[int, int]]:
    """Split ``range(n_samples)`` into ``n_groups`` contiguous half-open ranges.

    Thin wrapper over ``create_contiguous_partitions``.

    Parameters
    ----------
    n_samples : int
        Total number of samples.

    Returns
    -------
    boundaries : list of tuple
        One ``(start_idx, end_idx)`` pair per group.

    Raises
    ------
    ValueError
        Presumably raised by ``create_contiguous_partitions`` when the
        partition would violate CPCV invariants (TODO confirm in cpcv.partitioning).
    """
    group_count = self.n_groups
    return create_contiguous_partitions(n_samples, group_count)
860
+
861
def _validate_group_boundaries(self, boundaries: list[tuple[int, int]], n_samples: int) -> None:
    """Check that ``boundaries`` form a valid CPCV partition of ``n_samples``.

    Forwards to ``validate_contiguous_partitions``. Any return value it
    produces is discarded; a failed check is expected to surface as an
    exception raised by that helper.
    """
    validate_contiguous_partitions(boundaries, n_samples)
    return None
867
+
868
def _create_session_group_indices(
    self,
    X: pl.DataFrame | pd.DataFrame,
) -> list[NDArray[np.intp]]:
    """Partition rows into ``n_groups`` session-aligned groups of exact row indices.

    Thin wrapper over ``create_session_partitions``. In contrast to the
    ``(start, end)`` ranges from ``_create_group_boundaries``, each group here is
    an explicit index array. This keeps non-contiguous or interleaved data
    (e.g. several assets sharing sessions) correct.

    Parameters
    ----------
    X : DataFrame
        Data containing the column named by ``session_col``.

    Returns
    -------
    group_indices : list of np.ndarray
        Row indices belonging to each group.
    """
    session_column = self.session_col
    group_count = self.n_groups
    return create_session_partitions(X, session_column, group_count, self._session_to_indices)
893
+
894
@staticmethod
def _timestamp_window_from_indices(
    indices: NDArray[np.intp],
    timestamps: pd.DatetimeIndex,
) -> tuple[pd.Timestamp, pd.Timestamp] | None:
    """Return the timestamp span covered by ``indices`` for session-aligned purging.

    Wraps ``timestamp_window_from_indices`` and unpacks its window object into
    a plain tuple.

    Parameters
    ----------
    indices : ndarray
        Row indices of the test samples.
    timestamps : pd.DatetimeIndex
        Timestamps for every sample.

    Returns
    -------
    tuple or None
        ``(start_time, end_time_exclusive)``, or ``None`` when the helper
        reports no window (empty ``indices``).
    """
    span = timestamp_window_from_indices(indices, timestamps)
    return None if span is None else (span.start, span.end_exclusive)
919
+
920
def _apply_group_purging_and_embargo(
    self,
    train_indices: NDArray[np.intp],
    test_group_indices: tuple[int, ...],
    group_boundaries: list[tuple[int, int]],
    n_samples: int,
    timestamps: pd.DatetimeIndex | None,
    groups: pl.Series | pd.Series | NDArray[Any] | None = None,
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> NDArray[np.intp]:
    """Purge and embargo the training indices against the chosen test groups.

    Dispatches on ``groups``. Without group labels, a single global timeline is
    purged. With group labels, purging happens per group (e.g. per asset), so
    one asset's test period does not purge another asset's training rows.

    Parameters
    ----------
    train_indices : ndarray
        Candidate training indices before purging.
    test_group_indices : tuple of int
        Ids of the groups forming the test set.
    group_boundaries : list of tuple
        ``(start, end)`` boundaries of every group.
    n_samples : int
        Total number of samples.
    timestamps : pd.DatetimeIndex, optional
        Sample timestamps, when available.
    groups : array-like, optional
        Per-sample group labels (e.g. asset ids). ``None`` selects the
        single-asset path.
    group_indices_list : list of ndarray, optional
        Exact indices per group (session-aligned mode). Combined with
        timestamps, purging uses real timestamp bounds instead of
        ``(min_idx, max_idx)`` boundaries.

    Returns
    -------
    clean_indices : ndarray
        Training indices that survive purging and embargo.
    """
    shared_args = (train_indices, test_group_indices, group_boundaries, n_samples, timestamps)

    if groups is not None:
        # Multi-asset: purge each group's timeline independently.
        return self._apply_multi_asset_purging(*shared_args, groups, group_indices_list)

    # Single-asset: one global timeline.
    return self._apply_single_asset_purging(*shared_args, group_indices_list)
987
+
988
def _apply_single_asset_purging(
    self,
    train_indices: NDArray[np.intp],
    test_group_indices: tuple[int, ...],
    group_boundaries: list[tuple[int, int]],
    n_samples: int,
    timestamps: pd.DatetimeIndex | None,
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> NDArray[np.intp]:
    """Purge/embargo a single global timeline.

    Forwards the split geometry plus this splitter's ``label_horizon``,
    ``embargo_size`` and ``embargo_pct`` settings to
    ``apply_single_asset_purging``.
    """
    leakage_settings = {
        "label_horizon": self.label_horizon,
        "embargo_size": self.embargo_size,
        "embargo_pct": self.embargo_pct,
    }
    return apply_single_asset_purging(
        train_indices=train_indices,
        test_group_indices=test_group_indices,
        group_boundaries=group_boundaries,
        n_samples=n_samples,
        timestamps=timestamps,
        group_indices_list=group_indices_list,
        **leakage_settings,
    )
1012
+
1013
def _apply_multi_asset_purging(
    self,
    train_indices: NDArray[np.intp],
    test_group_indices: tuple[int, ...],
    group_boundaries: list[tuple[int, int]],
    n_samples: int,
    timestamps: pd.DatetimeIndex | None,
    groups: pl.Series | pd.Series | NDArray[Any],
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> NDArray[np.intp]:
    """Purge/embargo each group (asset) separately.

    Converts ``groups`` to a flat 1-D numpy array. ``flatten`` returns a copy,
    so the caller's data is never aliased. The result, together with this
    splitter's leakage settings, is passed to ``apply_multi_asset_purging``.
    """
    flat_group_labels = DataFrameAdapter.to_numpy(groups).flatten()

    leakage_settings = {
        "label_horizon": self.label_horizon,
        "embargo_size": self.embargo_size,
        "embargo_pct": self.embargo_pct,
    }
    return apply_multi_asset_purging(
        train_indices=train_indices,
        test_group_indices=test_group_indices,
        group_boundaries=group_boundaries,
        n_samples=n_samples,
        timestamps=timestamps,
        groups_array=flat_group_labels,
        group_indices_list=group_indices_list,
        **leakage_settings,
    )
1042
+
1043
def _validate_inputs(
    self,
    X: pl.DataFrame | pd.DataFrame | NDArray[Any],
    y: pl.Series | pd.Series | NDArray[Any] | None = None,
    groups: pl.Series | pd.Series | NDArray[Any] | None = None,
) -> int:
    """Check input shapes and return the number of samples.

    Shape checks are delegated to ``self._validate_data``, which works on the
    native input types without converting them to numpy. On top of that, the
    data must contain at least one sample per group.

    Raises
    ------
    ValueError
        If ``n_samples // n_groups`` is below 1.
    """
    n_samples = self._validate_data(X, y, groups)

    group_count = self.n_groups
    per_group = n_samples // group_count
    if per_group < 1:
        raise ValueError(
            f"Not enough samples ({n_samples}) for {group_count} groups. Need at least {group_count} samples.",
        )

    return n_samples