ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/splitters/walk_forward.py
@@ -0,0 +1,757 @@
+ """Walk-forward cross-validation with purging and embargo.
+
+ This module implements walk-forward cross-validation that prevents data leakage
+ through purging and embargo, suitable for time-series financial data.
+ """
+
+ from collections.abc import Generator
+ from typing import TYPE_CHECKING, Any, Union, cast
+
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+
+ from ml4t.diagnostic.core.purging import apply_purging_and_embargo
+ from ml4t.diagnostic.splitters.base import BaseSplitter
+ from ml4t.diagnostic.splitters.calendar import TradingCalendar, parse_time_size_calendar_aware
+ from ml4t.diagnostic.splitters.calendar_config import CalendarConfig
+ from ml4t.diagnostic.splitters.config import PurgedWalkForwardConfig
+ from ml4t.diagnostic.splitters.group_isolation import isolate_groups_from_train
+ from ml4t.diagnostic.splitters.utils import convert_indices_to_timestamps
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+
+ class PurgedWalkForwardCV(BaseSplitter):
+     """Walk-forward cross-validator with purging and embargo.
+
+     Walk-forward CV creates sequential train/test splits where training data
+     always precedes test data. This implementation adds purging and embargo
+     to prevent data leakage from label overlap and serial correlation.
+
+     Parameters
+     ----------
+     n_splits : int, default=5
+         Number of splits to generate.
+
+     test_size : int, float, str, or None, optional
+         Size of each test set:
+         - If int: number of samples (e.g., 1000)
+         - If float: proportion of dataset (e.g., 0.1)
+         - If str: time period using pandas offset aliases (e.g., "4W", "30D", "3M")
+         - If None: uses 1 / (n_splits + 1)
+         Time-based specifications require X to have a DatetimeIndex.
+
+     train_size : int, float, str, or None, optional
+         Size of each training set:
+         - If int: number of samples (e.g., 10000)
+         - If float: proportion of dataset (e.g., 0.5)
+         - If str: time period using pandas offset aliases (e.g., "78W", "6M", "2Y")
+         - If None: uses all available data before test set
+         Time-based specifications require X to have a DatetimeIndex.
+
+     gap : int, default=0
+         Gap between training and test set (in addition to purging).
+
+     label_horizon : int or pd.Timedelta, default=0
+         Forward-looking period of labels for purging calculation.
+
+     embargo_size : int or pd.Timedelta, optional
+         Size of embargo period after each test set.
+
+     embargo_pct : float, optional
+         Embargo size as percentage of total samples.
+
+     expanding : bool, default=True
+         If True, training window expands with each split.
+         If False, uses fixed-size rolling window.
+
+     consecutive : bool, default=False
+         If True, uses consecutive (back-to-back) test periods with no gaps.
+         This is appropriate for walk-forward validation where you want to
+         simulate realistic trading with sequential validation periods.
+         If False, spreads test periods across the dataset to sample different
+         time periods (useful for testing robustness across market regimes).
+
+     calendar : str, CalendarConfig, or TradingCalendar, optional
+         Trading calendar for calendar-aware time period calculations.
+         - If str: name of a pandas_market_calendars calendar (e.g., 'CME_Equity', 'NYSE');
+           creates a default CalendarConfig with UTC timezone
+         - If CalendarConfig: full configuration with exchange, timezone, and options
+         - If TradingCalendar: pre-configured calendar instance
+         - If None: uses naive time-based calculation (backward compatible)
+
+         For intraday data with time-based test_size/train_size (e.g., '4W'),
+         using a calendar ensures proper session-aware splitting:
+         - Trading sessions are atomic units (won't split Sunday 5pm - Friday 4pm)
+         - Handles varying data density in activity-based data (dollar bars, trade bars)
+         - Proper timezone handling for tz-naive and tz-aware data
+         - '1D' selections: complete trading sessions
+         - '4W' selections: complete trading weeks (e.g., 4 weeks of 5 sessions each)
+
+         Examples:
+         >>> from ml4t.diagnostic.splitters.calendar_config import CME_CONFIG
+         >>> cv = PurgedWalkForwardCV(test_size='4W', calendar=CME_CONFIG)  # CME futures
+         >>> cv = PurgedWalkForwardCV(test_size='1W', calendar='NYSE')  # US equities (simple)
+
+     align_to_sessions : bool, default=False
+         If True, align fold boundaries to trading session boundaries.
+         Requires X to have a session column (specified by the session_col parameter).
+
+         Trading sessions should be assigned using the qdata library before cross-validation:
+         - Use DataManager with exchange/calendar parameters, or
+         - Use SessionAssigner.from_exchange('CME') directly
+
+         When enabled, fold boundaries will never split a trading session, preventing
+         subtle lookahead bias in intraday strategies.
+
+     session_col : str, default='session_date'
+         Name of the column containing session identifiers.
+         Only used if align_to_sessions=True.
+         This column should be added by qdata.sessions.SessionAssigner.
+
+     timestamp_col : str, optional
+         Name of the column containing timestamps, used for time-based
+         test_size/train_size specifications with Polars DataFrames
+         (pandas DataFrames can use a DatetimeIndex instead).
+
+     isolate_groups : bool, default=False
+         If True, prevent the same group (asset/symbol) from appearing in both
+         train and test sets. This is critical for multi-asset validation to
+         avoid data leakage.
+
+         Requires passing `groups` parameter to split() method with asset IDs.
+
+         Example:
+         >>> cv = PurgedWalkForwardCV(n_splits=5, isolate_groups=True)
+         >>> for train, test in cv.split(df, groups=df['symbol']):
+         ...     # train and test will have completely different symbols
+         ...     pass
+
+     Attributes
+     ----------
+     n_splits_ : int
+         The number of splits.
+
+     Examples
+     --------
+     >>> import numpy as np
+     >>> from ml4t.diagnostic.splitters import PurgedWalkForwardCV
+     >>> X = np.arange(100).reshape(100, 1)
+     >>> cv = PurgedWalkForwardCV(n_splits=3, label_horizon=5, embargo_size=2)
+     >>> for train, test in cv.split(X):
+     ...     print(f"Train: {len(train)}, Test: {len(test)}")
+     Train: 17, Test: 25
+     Train: 40, Test: 25
+     Train: 63, Test: 25
+     """
+
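[Note] The actual purging/embargo arithmetic lives in ml4t.diagnostic.core.purging (diffed separately above). As a rough mental model only — a minimal sketch with integer indices, not the package's apply_purging_and_embargo, which also handles timestamps, Timedelta horizons, and embargo_pct — purging drops training samples whose forward-looking label window reaches into the test period, and the embargo drops samples in a buffer just after it:

import numpy as np

def purge_and_embargo_sketch(train_idx, test_start, test_end, label_horizon, embargo):
    # Purge: label window [i, i + label_horizon] overlaps the test period [test_start, test_end)
    purged = (train_idx + label_horizon >= test_start) & (train_idx < test_end)
    # Embargo: training samples in the buffer immediately after the test period
    embargoed = (train_idx >= test_end) & (train_idx < test_end + embargo)
    return train_idx[~(purged | embargoed)]

train = np.arange(25)
print(purge_and_embargo_sketch(train, 25, 50, label_horizon=5, embargo=2))
# Indices 20-24 are purged: their 5-step label windows overlap the test period.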
+     def __init__(
+         self,
+         config: PurgedWalkForwardConfig | None = None,
+         *,
+         n_splits: int = 5,
+         test_size: int | float | str | None = None,
+         train_size: int | float | str | None = None,
+         gap: int = 0,
+         label_horizon: int | pd.Timedelta = 0,
+         embargo_size: int | pd.Timedelta | None = None,
+         embargo_pct: float | None = None,
+         expanding: bool = True,
+         consecutive: bool = False,
+         calendar: str | CalendarConfig | TradingCalendar | None = None,
+         align_to_sessions: bool = False,
+         session_col: str = "session_date",
+         timestamp_col: str | None = None,
+         isolate_groups: bool = False,
+     ) -> None:
+         """Initialize PurgedWalkForwardCV.
+
+         This splitter uses a config-first architecture. You can either:
+         1. Pass a config object: PurgedWalkForwardCV(config=my_config)
+         2. Pass individual parameters: PurgedWalkForwardCV(n_splits=5, test_size=100)
+
+         Parameters are automatically converted to a config object internally,
+         ensuring a single source of truth for all validation and logic.
+
+         Examples
+         --------
+         >>> # Approach 1: Direct parameters (convenient)
+         >>> cv = PurgedWalkForwardCV(n_splits=5, test_size=100)
+         >>>
+         >>> # Approach 2: Config object (for serialization/reproducibility)
+         >>> from ml4t.diagnostic.splitters.config import PurgedWalkForwardConfig
+         >>> config = PurgedWalkForwardConfig(n_splits=5, test_size=100)
+         >>> cv = PurgedWalkForwardCV(config=config)
+         >>>
+         >>> # Config can be serialized
+         >>> config.to_json("cv_config.json")
+         >>> loaded = PurgedWalkForwardConfig.from_json("cv_config.json")
+         >>> cv = PurgedWalkForwardCV(config=loaded)
+         """
+         # Config-first: either use provided config or create from params
+         if config is not None:
+             # Explicit config provided.
+             # Verify no conflicting parameters were passed.
+             non_default_params = []
+             if n_splits != 5:
+                 non_default_params.append("n_splits")
+             if test_size is not None:
+                 non_default_params.append("test_size")
+             if train_size is not None:
+                 non_default_params.append("train_size")
+             if gap != 0:
+                 non_default_params.append("gap")
+             if label_horizon != 0:
+                 non_default_params.append("label_horizon")
+             if embargo_size is not None:
+                 non_default_params.append("embargo_size")
+             if embargo_pct is not None:
+                 non_default_params.append("embargo_pct")
+             if not expanding:
+                 non_default_params.append("expanding")
+             if consecutive:
+                 non_default_params.append("consecutive")
+             if calendar is not None:
+                 non_default_params.append("calendar")
+             if align_to_sessions:
+                 non_default_params.append("align_to_sessions")
+             if session_col != "session_date":
+                 non_default_params.append("session_col")
+             if timestamp_col is not None:
+                 non_default_params.append("timestamp_col")
+             if isolate_groups:
+                 non_default_params.append("isolate_groups")
+
+             if non_default_params:
+                 raise ValueError(
+                     f"Cannot specify both 'config' and individual parameters. "
+                     f"Got config plus: {', '.join(non_default_params)}"
+                 )
+
+             self.config = config
+         else:
+             # Create config from individual parameters.
+             # Note: embargo_size maps to embargo_td in config.
+             self.config = PurgedWalkForwardConfig(
+                 n_splits=n_splits,
+                 test_size=test_size,
+                 train_size=train_size,
+                 label_horizon=label_horizon,
+                 embargo_td=embargo_size,
+                 align_to_sessions=align_to_sessions,
+                 session_col=session_col,
+                 timestamp_col=timestamp_col,
+                 isolate_groups=isolate_groups,
+             )
+
+         # Handle calendar initialization
+         # NOTE: Calendar config could be moved to WalkForwardConfig in a future version
+         if calendar is None:
+             self.calendar = None
+         elif isinstance(calendar, str | CalendarConfig):
+             self.calendar = TradingCalendar(calendar)
+         elif isinstance(calendar, TradingCalendar):
+             self.calendar = calendar
+         else:
+             raise TypeError(
+                 f"calendar must be str, CalendarConfig, TradingCalendar, or None, got {type(calendar)}"
+             )
+
+         # Legacy attributes for compatibility with the existing split() implementation.
+         # These are not stored in the config; they are kept directly on the instance.
+         self.gap = gap
+         self.embargo_pct = embargo_pct
+         self.expanding = expanding
+         self.consecutive = consecutive
+
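[Note] One consequence of the conflict check above: passing a config together with any non-default keyword raises immediately. A hypothetical snippet (field names as used in this constructor):

from ml4t.diagnostic.splitters import PurgedWalkForwardCV
from ml4t.diagnostic.splitters.config import PurgedWalkForwardConfig

cfg = PurgedWalkForwardConfig(n_splits=4, test_size=250)
cv = PurgedWalkForwardCV(config=cfg)       # fine: config only
try:
    PurgedWalkForwardCV(config=cfg, gap=3)  # config plus an explicit parameter
except ValueError as e:
    print(e)  # "Cannot specify both 'config' and individual parameters. Got config plus: gap"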
+     # Property accessors for config values (clean API)
+     @property
+     def n_splits(self) -> int:
+         """Number of cross-validation folds."""
+         return self.config.n_splits
+
+     @property
+     def test_size(self) -> int | float | str | None:
+         """Test set size specification."""
+         return self.config.test_size
+
+     @property
+     def train_size(self) -> int | float | str | None:
+         """Training set size specification."""
+         return self.config.train_size
+
+     @property
+     def label_horizon(self) -> int | pd.Timedelta:
+         """Forward-looking period of labels."""
+         return self.config.label_horizon
+
+     @property
+     def embargo_size(self) -> int | pd.Timedelta | None:
+         """Embargo buffer size."""
+         return self.config.embargo_td
+
+     @property
+     def align_to_sessions(self) -> bool:
+         """Whether to align fold boundaries to sessions."""
+         return self.config.align_to_sessions
+
+     @property
+     def session_col(self) -> str:
+         """Column name containing session identifiers."""
+         return self.config.session_col
+
+     @property
+     def timestamp_col(self) -> str | None:
+         """Column name containing timestamps for time-based sizes."""
+         return self.config.timestamp_col
+
+     @property
+     def isolate_groups(self) -> bool:
+         """Whether to prevent group overlap between train/test."""
+         return self.config.isolate_groups
+
+     def _parse_time_size(
+         self,
+         size_spec: int | float | str,
+         timestamps: pd.DatetimeIndex | None,
+         n_samples: int,
+     ) -> int:
+         """Parse size specification and convert to sample count.
+
+         Uses calendar-aware logic if a calendar is configured, otherwise falls back
+         to naive time-based calculation.
+
+         Parameters
+         ----------
+         size_spec : int, float, or str
+             Size specification to parse.
+         timestamps : pd.DatetimeIndex or None
+             Datetime index of the data.
+         n_samples : int
+             Total number of samples in dataset.
+
+         Returns
+         -------
+         int
+             Number of samples corresponding to the size specification.
+         """
+         if isinstance(size_spec, str):
+             # Time-based specification (e.g., "4W", "30D", "3M")
+             if timestamps is None:
+                 raise ValueError(
+                     "Time-based size specifications require timestamps. "
+                     "For pandas DataFrames: use a DatetimeIndex. "
+                     "For Polars DataFrames: set timestamp_col='your_datetime_column'. "
+                     "Example: PurgedWalkForwardCV(test_size='4W', timestamp_col='date')"
+                 )
+
+             # Use calendar-aware parsing if calendar is configured
+             return parse_time_size_calendar_aware(
+                 size_spec=size_spec,
+                 timestamps=timestamps,
+                 calendar=self.calendar,
+             )
+
+         elif isinstance(size_spec, float):
+             # Proportion of dataset
+             return int(n_samples * size_spec)
+         else:
+             # Integer sample count
+             return size_spec
+
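[Note] For intuition about the string branch, a naive calendar-free conversion might count the samples that fall inside one offset-sized window of the index. This is a hypothetical sketch, not parse_time_size_calendar_aware itself:

import pandas as pd

def naive_time_size(size_spec: str, timestamps: pd.DatetimeIndex) -> int:
    # Count samples whose timestamp falls within one `size_spec` window
    # starting at the first timestamp (e.g., "30D" -> samples in the first 30 days).
    window_end = timestamps[0] + pd.tseries.frequencies.to_offset(size_spec)
    return int((timestamps < window_end).sum())

idx = pd.date_range("2024-01-01", periods=100, freq="D")
print(naive_time_size("30D", idx))  # 30 daily samples in a 30-day window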
+     def get_n_splits(
+         self,
+         X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"] | None = None,
+         y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
+         groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
+     ) -> int:
+         """Get number of splits.
+
+         Parameters
+         ----------
+         X : array-like, optional
+             Always ignored, exists for compatibility.
+
+         y : array-like, optional
+             Always ignored, exists for compatibility.
+
+         groups : array-like, optional
+             Always ignored, exists for compatibility.
+
+         Returns
+         -------
+         n_splits : int
+             Number of splits.
+         """
+         del X, y, groups  # Unused, for sklearn compatibility
+         return self.n_splits
+
+     def split(
+         self,
+         X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
+         y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
+         groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
+     ) -> Generator[tuple["NDArray[np.intp]", "NDArray[np.intp]"], None, None]:
+         """Generate train/test indices for walk-forward splits.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Training data.
+
+         y : array-like of shape (n_samples,), optional
+             Target variable.
+
+         groups : array-like of shape (n_samples,), optional
+             Group labels for samples.
+
+         Yields
+         ------
+         train : ndarray
+             Training set indices for this split.
+
+         test : ndarray
+             Test set indices for this split.
+         """
+         # Validate inputs and get sample count
+         n_samples = self._validate_data(X, y, groups)
+
+         # Validate session alignment if enabled
+         self._validate_session_alignment(X, self.align_to_sessions, self.session_col)
+
+         # Branch between session-based and sample-based logic
+         if self.align_to_sessions:
+             # Session-aware splitting: operate on unique sessions.
+             # X is verified to be a DataFrame by _validate_session_alignment.
+             yield from self._split_by_sessions(
+                 cast(pl.DataFrame | pd.DataFrame, X), y, groups, n_samples
+             )
+         else:
+             # Standard sample-based splitting
+             yield from self._split_by_samples(X, y, groups, n_samples)
+
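[Note] Typical usage with a pandas DatetimeIndex, per the class docstring. The data and parameter values below are arbitrary illustrations (assuming the wheel above is installed):

import numpy as np
import pandas as pd
from ml4t.diagnostic.splitters import PurgedWalkForwardCV

idx = pd.date_range("2024-01-01", periods=252, freq="B")
X = pd.DataFrame({"ret": np.random.default_rng(0).normal(size=252)}, index=idx)

cv = PurgedWalkForwardCV(n_splits=4, test_size="2W", label_horizon=5, embargo_size=2)
for fold, (train_idx, test_idx) in enumerate(cv.split(X)):
    # Train indices always precede test indices; purging trims the boundary.
    print(fold, len(train_idx), len(test_idx))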
+     def _split_by_samples(
+         self,
+         X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
+         _y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None,
+         groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None,
+         n_samples: int,
+     ) -> Generator[tuple["NDArray[np.intp]", "NDArray[np.intp]"], None, None]:
+         """Generate splits using sample indices (original implementation)."""
+         # Extract timestamps if available (supports both Polars and pandas)
+         timestamps = self._extract_timestamps(X, self.timestamp_col)
+
+         # Calculate test size
+         if self.test_size is None:
+             test_size = n_samples // (self.n_splits + 1)
+         else:
+             test_size = self._parse_time_size(self.test_size, timestamps, n_samples)
+
+         # Calculate train size if specified
+         if self.train_size is not None:
+             train_size = self._parse_time_size(self.train_size, timestamps, n_samples)
+         else:
+             train_size = None
+
+         # Calculate split points
+         if self.consecutive:
+             # Consecutive walk-forward: back-to-back test periods with no gaps.
+             # Useful for realistic trading simulation where test periods are sequential.
+             step_size = test_size
+
+             # Determine where first test period starts
+             if train_size is not None and not self.expanding:
+                 # Rolling window: first test comes after initial training window
+                 first_test_start = train_size
+             elif self.expanding:
+                 # Expanding window: ensure we have enough data for minimum train_size,
+                 # or default to test_size if train_size not specified
+                 first_test_start = train_size if train_size is not None else test_size
+             else:
+                 # No train_size specified and not expanding: start after first test-sized chunk
+                 first_test_start = test_size
+
+             # Validate we have enough data for all consecutive periods
+             total_required = first_test_start + self.n_splits * test_size
+             if total_required > n_samples:
+                 raise ValueError(
+                     f"Insufficient data for consecutive={self.consecutive}: "
+                     f"need {total_required:,} samples (first_test at {first_test_start:,} "
+                     f"+ {self.n_splits} × {test_size:,}), but only have {n_samples:,}"
+                 )
+         else:
+             # Spread folds across available data to sample different time periods.
+             # Useful for testing robustness across different market regimes.
+             available_for_splits = n_samples - test_size
+             step_size = available_for_splits // self.n_splits
+             first_test_start = test_size
+
+         for i in range(self.n_splits):
+             # Calculate test indices
+             test_start = first_test_start + i * step_size
+             test_end = min(test_start + test_size, n_samples)
+
+             # For the last split, optionally use all remaining data
+             # (only if test_size was not explicitly specified)
+             if i == self.n_splits - 1 and self.test_size is None:
+                 test_end = n_samples
+
+             # Calculate train indices
+             if self.expanding:
+                 # Expanding window: use all data from start
+                 train_start = 0
+             else:
+                 # Rolling window
+                 if train_size is not None:
+                     train_start = max(0, test_start - self.gap - train_size)
+                 else:
+                     # If no train_size specified, use all available data
+                     train_start = 0
+
+             # Apply gap
+             train_end = test_start - self.gap
+
+             # Initial train indices (before purging/embargo)
+             train_indices = np.arange(train_start, train_end)
+
+             # Convert test boundaries to timestamps if needed
+             test_start_time, test_end_time = convert_indices_to_timestamps(
+                 test_start,
+                 test_end,
+                 timestamps,
+             )
+
+             # Apply purging and embargo
+             clean_train_indices = apply_purging_and_embargo(
+                 train_indices=train_indices,
+                 test_start=test_start_time,
+                 test_end=test_end_time,
+                 label_horizon=self.label_horizon,
+                 embargo_size=self.embargo_size,
+                 embargo_pct=self.embargo_pct,
+                 n_samples=n_samples,
+                 timestamps=timestamps,
+             )
+
+             # Test indices
+             test_indices = np.arange(test_start, test_end, dtype=np.intp)
+
+             # Apply group isolation if requested
+             if self.isolate_groups and groups is not None:
+                 clean_train_indices = isolate_groups_from_train(
+                     clean_train_indices, test_indices, groups
+                 )
+
+             yield clean_train_indices.astype(np.intp), test_indices
+
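[Note] To make the fold-placement arithmetic concrete, the same calculation in isolation (mirroring the code above with expanding=True and no train_size, so first_test_start equals test_size in both modes):

n_samples, n_splits, test_size = 100, 3, 20

# consecutive=False (default): spread test windows across the data
step = (n_samples - test_size) // n_splits  # (100 - 20) // 3 = 26
print([(20 + i * step, 20 + i * step + test_size) for i in range(n_splits)])
# [(20, 40), (46, 66), (72, 92)]

# consecutive=True: back-to-back test windows (step == test_size)
print([(20 + i * test_size, 20 + (i + 1) * test_size) for i in range(n_splits)])
# [(20, 40), (40, 60), (60, 80)]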
+     def _split_by_sessions(
+         self,
+         X: pl.DataFrame | pd.DataFrame,
+         _y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None,
+         groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None,
+         n_samples: int,
+     ) -> Generator[tuple["NDArray[np.intp]", "NDArray[np.intp]"], None, None]:
+         """Generate splits using session boundaries (session-aware)."""
+         # Get unique sessions in chronological order
+         unique_sessions = self._get_unique_sessions(X, self.session_col)
+         n_sessions = len(unique_sessions)
+
+         # Extract timestamps if available (for purging/embargo)
+         timestamps = self._extract_timestamps(X, self.timestamp_col)
+
+         # Calculate test size in sessions
+         if self.test_size is None:
+             test_size_sessions = n_sessions // (self.n_splits + 1)
+         elif isinstance(self.test_size, int):
+             # Integer test_size: interpret as number of sessions
+             test_size_sessions = self.test_size
+         elif isinstance(self.test_size, float):
+             # Float test_size: proportion of sessions
+             test_size_sessions = int(n_sessions * self.test_size)
+         else:
+             # Time-based test_size not supported with sessions
+             raise ValueError(
+                 f"align_to_sessions=True does not support time-based test_size. "
+                 f"Use integer (number of sessions) or float (proportion). Got: {self.test_size}"
+             )
+
+         # Calculate train size in sessions if specified
+         if self.train_size is not None:
+             if isinstance(self.train_size, int):
+                 train_size_sessions = self.train_size
+             elif isinstance(self.train_size, float):
+                 train_size_sessions = int(n_sessions * self.train_size)
+             else:
+                 raise ValueError(
+                     f"align_to_sessions=True does not support time-based train_size. "
+                     f"Use integer (number of sessions) or float (proportion). Got: {self.train_size}"
+                 )
+         else:
+             train_size_sessions = None
+
+         # Calculate split points in session space
+         if self.consecutive:
+             step_size_sessions = test_size_sessions
+
+             if train_size_sessions is not None and not self.expanding:
+                 first_test_start_session = train_size_sessions
+             elif self.expanding:
+                 first_test_start_session = (
+                     train_size_sessions if train_size_sessions is not None else test_size_sessions
+                 )
+             else:
+                 first_test_start_session = test_size_sessions
+
+             total_required_sessions = first_test_start_session + self.n_splits * test_size_sessions
+             if total_required_sessions > n_sessions:
+                 raise ValueError(
+                     f"Insufficient sessions for consecutive={self.consecutive}: "
+                     f"need {total_required_sessions:,} sessions (first_test at {first_test_start_session:,} "
+                     f"+ {self.n_splits} × {test_size_sessions:,}), but only have {n_sessions:,}"
+                 )
+         else:
+             available_for_splits_sessions = n_sessions - test_size_sessions
+             step_size_sessions = available_for_splits_sessions // self.n_splits
+             first_test_start_session = test_size_sessions
+
+         # Generate splits by mapping session ranges to row indices
+         for i in range(self.n_splits):
+             # Calculate test session range
+             test_start_session = first_test_start_session + i * step_size_sessions
+             test_end_session = min(test_start_session + test_size_sessions, n_sessions)
+
+             if i == self.n_splits - 1 and self.test_size is None:
+                 test_end_session = n_sessions
+
+             # Calculate train session range
+             if self.expanding:
+                 train_start_session = 0
+             else:
+                 if train_size_sessions is not None:
+                     train_start_session = max(
+                         0, test_start_session - self.gap - train_size_sessions
+                     )
+                 else:
+                     train_start_session = 0
+
+             train_end_session = test_start_session - self.gap
+
+             # Get session IDs for train and test
+             if isinstance(unique_sessions, pl.Series):
+                 train_sessions = unique_sessions[train_start_session:train_end_session].to_list()
+                 test_sessions = unique_sessions[test_start_session:test_end_session].to_list()
+                 session_col_values = X[self.session_col]
+             else:  # pandas Series
+                 train_sessions = unique_sessions.iloc[
+                     train_start_session:train_end_session
+                 ].tolist()
+                 test_sessions = unique_sessions.iloc[test_start_session:test_end_session].tolist()
+                 session_col_values = X[self.session_col]
+
+             # Map sessions to row indices
+             if isinstance(X, pl.DataFrame):
+                 train_mask = session_col_values.is_in(train_sessions)
+                 test_mask = session_col_values.is_in(test_sessions)
+                 train_indices = np.where(train_mask.to_numpy())[0]
+                 test_indices = np.where(test_mask.to_numpy())[0]
+             else:  # pandas DataFrame
+                 # Cast to pd.Series since X is pd.DataFrame here
+                 session_col_pd = cast(pd.Series, session_col_values)
+                 train_mask = session_col_pd.isin(train_sessions)
+                 test_mask = session_col_pd.isin(test_sessions)
+                 train_indices = np.where(train_mask.to_numpy())[0]
+                 test_indices = np.where(test_mask.to_numpy())[0]
+
+             # Apply purging and embargo if configured
+             if self._has_purging_or_embargo():
+                 # Compute actual timestamp bounds from test indices.
+                 # This is critical for multi-asset data where rows may be sorted by
+                 # asset rather than time - using positional indices [0] and [-1] would
+                 # give incorrect timestamp bounds.
+                 test_start_time, test_end_time = self._timestamp_window_from_indices(
+                     test_indices, timestamps
+                 )
+
+                 clean_train_indices = apply_purging_and_embargo(
+                     train_indices=train_indices,
+                     test_start=test_start_time,
+                     test_end=test_end_time,
+                     label_horizon=self.label_horizon,
+                     embargo_size=self.embargo_size,
+                     embargo_pct=self.embargo_pct,
+                     n_samples=n_samples,
+                     timestamps=timestamps,
+                 )
+             else:
+                 clean_train_indices = train_indices
+
+             # Apply group isolation if requested
+             if self.isolate_groups and groups is not None:
+                 clean_train_indices = isolate_groups_from_train(
+                     clean_train_indices, test_indices, groups
+                 )
+
+             yield clean_train_indices.astype(np.intp), test_indices.astype(np.intp)
+
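[Note] The session-to-row mapping above is essentially a membership test on the session column; every row of a selected session is included, never a partial session. A self-contained Polars illustration with hypothetical data, mirroring the is_in/np.where step:

import numpy as np
import polars as pl

df = pl.DataFrame({
    "session_date": ["2024-01-02"] * 3 + ["2024-01-03"] * 2 + ["2024-01-04"] * 4,
    "price": list(range(9)),
})
test_sessions = ["2024-01-04"]

test_mask = df["session_date"].is_in(test_sessions)
test_indices = np.where(test_mask.to_numpy())[0]
print(test_indices)  # [5 6 7 8] - all four rows of the test session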
+     def _has_purging_or_embargo(self) -> bool:
+         """Check if purging or embargo is needed.
+
+         Handles both int and pd.Timedelta values for label_horizon and embargo_size.
+
+         Returns
+         -------
+         bool
+             True if purging or embargo should be applied.
+         """
+         # Check label_horizon (can be int or Timedelta)
+         has_label_horizon = False
+         if isinstance(self.label_horizon, int | float):
+             has_label_horizon = self.label_horizon > 0
+         elif hasattr(self.label_horizon, "total_seconds"):  # pd.Timedelta
+             has_label_horizon = self.label_horizon.total_seconds() > 0
+
+         # Check embargo (embargo_size can be int or Timedelta; embargo_pct is always float or None)
+         has_embargo = self.embargo_size is not None or self.embargo_pct is not None
+
+         return has_label_horizon or has_embargo
+
+     @staticmethod
+     def _timestamp_window_from_indices(
+         indices: "NDArray[np.intp]",
+         timestamps: pd.DatetimeIndex | None,
+     ) -> tuple[int | pd.Timestamp, int | pd.Timestamp]:
+         """Compute timestamp window from actual indices (for session-aligned purging).
+
+         This is critical for correct purging in session-aligned mode. Instead of
+         using positional indices [0] and [-1], which assume chronological ordering,
+         we compute the actual timestamp bounds from all test indices.
+
+         For multi-asset data where rows may be sorted by asset rather than time,
+         test_indices[0] may not have the minimum timestamp.
+
+         Parameters
+         ----------
+         indices : ndarray
+             Row indices of test samples.
+         timestamps : pd.DatetimeIndex or None
+             Timestamps for all samples. If None, returns index bounds.
+
+         Returns
+         -------
+         start_time : int or pd.Timestamp
+             Minimum timestamp of test indices (or min index if no timestamps).
+         end_time_exclusive : int or pd.Timestamp
+             Maximum timestamp + 1 nanosecond (or max index + 1 if no timestamps).
+         """
+         if len(indices) == 0:
+             # Empty indices - return minimal bounds
+             if timestamps is None:
+                 return 0, 0
+             return timestamps[0], timestamps[0]
+
+         if timestamps is None:
+             # No timestamps - return index bounds
+             return int(indices.min()), int(indices.max()) + 1
+
+         test_timestamps = timestamps.take(indices)
+         start_time = test_timestamps.min()
+         # Add 1 nanosecond to make end exclusive (handles duplicate timestamps)
+         end_time_exclusive = test_timestamps.max() + pd.Timedelta(1, "ns")
+         return start_time, end_time_exclusive
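[Note] The motivation for taking min/max over all test timestamps, rather than positions [0] and [-1], is easy to reproduce with asset-sorted data. A small pandas check with hypothetical data:

import numpy as np
import pandas as pd

# Rows sorted by asset, then time: timestamps are NOT monotonic across rows.
timestamps = pd.DatetimeIndex(
    ["2024-01-01", "2024-01-02", "2024-01-03",   # asset A
     "2024-01-01", "2024-01-02", "2024-01-03"]   # asset B
)
test_indices = np.array([1, 2, 3, 4], dtype=np.intp)  # last two of A, first two of B

test_ts = timestamps.take(test_indices)
print(test_ts[0], test_ts[-1])       # Jan 2 .. Jan 2 -> wrong, collapsed window
print(test_ts.min(), test_ts.max())  # Jan 1 .. Jan 3 -> correct bounds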