ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,757 @@
|
|
|
1
|
+
"""Walk-forward cross-validation with purging and embargo.
|
|
2
|
+
|
|
3
|
+
This module implements walk-forward cross-validation that prevents data leakage
|
|
4
|
+
through purging and embargo, suitable for time-series financial data.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from collections.abc import Generator
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Union, cast
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import polars as pl
|
|
13
|
+
|
|
14
|
+
from ml4t.diagnostic.core.purging import apply_purging_and_embargo
|
|
15
|
+
from ml4t.diagnostic.splitters.base import BaseSplitter
|
|
16
|
+
from ml4t.diagnostic.splitters.calendar import TradingCalendar, parse_time_size_calendar_aware
|
|
17
|
+
from ml4t.diagnostic.splitters.calendar_config import CalendarConfig
|
|
18
|
+
from ml4t.diagnostic.splitters.config import PurgedWalkForwardConfig
|
|
19
|
+
from ml4t.diagnostic.splitters.group_isolation import isolate_groups_from_train
|
|
20
|
+
from ml4t.diagnostic.splitters.utils import convert_indices_to_timestamps
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from numpy.typing import NDArray
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PurgedWalkForwardCV(BaseSplitter):
|
|
27
|
+
"""Walk-forward cross-validator with purging and embargo.
|
|
28
|
+
|
|
29
|
+
Walk-forward CV creates sequential train/test splits where training data
|
|
30
|
+
always precedes test data. This implementation adds purging and embargo
|
|
31
|
+
to prevent data leakage from label overlap and serial correlation.
|
|
32
|
+
|
|
33
|
+
Parameters
|
|
34
|
+
----------
|
|
35
|
+
n_splits : int, default=5
|
|
36
|
+
Number of splits to generate.
|
|
37
|
+
|
|
38
|
+
test_size : int, float, str, or None, optional
|
|
39
|
+
Size of each test set:
|
|
40
|
+
- If int: number of samples (e.g., 1000)
|
|
41
|
+
- If float: proportion of dataset (e.g., 0.1)
|
|
42
|
+
- If str: time period using pandas offset aliases (e.g., "4W", "30D", "3M")
|
|
43
|
+
- If None: uses 1 / (n_splits + 1)
|
|
44
|
+
Time-based specifications require X to have a DatetimeIndex.
|
|
45
|
+
|
|
46
|
+
train_size : int, float, str, or None, optional
|
|
47
|
+
Size of each training set:
|
|
48
|
+
- If int: number of samples (e.g., 10000)
|
|
49
|
+
- If float: proportion of dataset (e.g., 0.5)
|
|
50
|
+
- If str: time period using pandas offset aliases (e.g., "78W", "6M", "2Y")
|
|
51
|
+
- If None: uses all available data before test set
|
|
52
|
+
Time-based specifications require X to have a DatetimeIndex.
|
|
53
|
+
|
|
54
|
+
gap : int, default=0
|
|
55
|
+
Gap between training and test set (in addition to purging).
|
|
56
|
+
|
|
57
|
+
label_horizon : int or pd.Timedelta, default=0
|
|
58
|
+
Forward-looking period of labels for purging calculation.
|
|
59
|
+
|
|
60
|
+
embargo_size : int or pd.Timedelta, optional
|
|
61
|
+
Size of embargo period after each test set.
|
|
62
|
+
|
|
63
|
+
embargo_pct : float, optional
|
|
64
|
+
Embargo size as percentage of total samples.
|
|
65
|
+
|
|
66
|
+
expanding : bool, default=True
|
|
67
|
+
If True, training window expands with each split.
|
|
68
|
+
If False, uses fixed-size rolling window.
|
|
69
|
+
|
|
70
|
+
consecutive : bool, default=False
|
|
71
|
+
If True, uses consecutive (back-to-back) test periods with no gaps.
|
|
72
|
+
This is appropriate for walk-forward validation where you want to
|
|
73
|
+
simulate realistic trading with sequential validation periods.
|
|
74
|
+
If False, spreads test periods across the dataset to sample different
|
|
75
|
+
time periods (useful for testing robustness across market regimes).
|
|
76
|
+
|
|
77
|
+
calendar : str, CalendarConfig, or TradingCalendar, optional
|
|
78
|
+
Trading calendar for calendar-aware time period calculations.
|
|
79
|
+
- If str: Name of pandas_market_calendars calendar (e.g., 'CME_Equity', 'NYSE')
|
|
80
|
+
Creates default CalendarConfig with UTC timezone
|
|
81
|
+
- If CalendarConfig: Full configuration with exchange, timezone, and options
|
|
82
|
+
- If TradingCalendar: Pre-configured calendar instance
|
|
83
|
+
- If None: Uses naive time-based calculation (backward compatible)
|
|
84
|
+
|
|
85
|
+
For intraday data with time-based test_size/train_size (e.g., '4W'),
|
|
86
|
+
using a calendar ensures proper session-aware splitting:
|
|
87
|
+
- Trading sessions are atomic units (won't split Sunday 5pm - Friday 4pm)
|
|
88
|
+
- Handles varying data density in activity-based data (dollar bars, trade bars)
|
|
89
|
+
- Proper timezone handling for tz-naive and tz-aware data
|
|
90
|
+
- '1D' selections: Complete trading sessions
|
|
91
|
+
- '4W' selections: Complete trading weeks (e.g., 4 weeks of 5 sessions each)
|
|
92
|
+
|
|
93
|
+
Examples:
|
|
94
|
+
>>> from ml4t.diagnostic.splitters.calendar_config import CME_CONFIG
|
|
95
|
+
>>> cv = PurgedWalkForwardCV(test_size='4W', calendar=CME_CONFIG) # CME futures
|
|
96
|
+
>>> cv = PurgedWalkForwardCV(test_size='1W', calendar='NYSE') # US equities (simple)
|
|
97
|
+
|
|
98
|
+
align_to_sessions : bool, default=False
|
|
99
|
+
If True, align fold boundaries to trading session boundaries.
|
|
100
|
+
Requires X to have a session column (specified by session_col parameter).
|
|
101
|
+
|
|
102
|
+
Trading sessions should be assigned using the qdata library before cross-validation:
|
|
103
|
+
- Use DataManager with exchange/calendar parameters, or
|
|
104
|
+
- Use SessionAssigner.from_exchange('CME') directly
|
|
105
|
+
|
|
106
|
+
When enabled, fold boundaries will never split a trading session, preventing
|
|
107
|
+
subtle lookahead bias in intraday strategies.
|
|
108
|
+
|
|
109
|
+
session_col : str, default='session_date'
|
|
110
|
+
Name of the column containing session identifiers.
|
|
111
|
+
Only used if align_to_sessions=True.
|
|
112
|
+
This column should be added by qdata.sessions.SessionAssigner
|
|
113
|
+
|
|
114
|
+
isolate_groups : bool, default=False
|
|
115
|
+
If True, prevent the same group (asset/symbol) from appearing in both
|
|
116
|
+
train and test sets. This is critical for multi-asset validation to
|
|
117
|
+
avoid data leakage.
|
|
118
|
+
|
|
119
|
+
Requires passing `groups` parameter to split() method with asset IDs.
|
|
120
|
+
|
|
121
|
+
Example:
|
|
122
|
+
>>> cv = PurgedWalkForwardCV(n_splits=5, isolate_groups=True)
|
|
123
|
+
>>> for train, test in cv.split(df, groups=df['symbol']):
|
|
124
|
+
... # train and test will have completely different symbols
|
|
125
|
+
... pass
|
|
126
|
+
|
|
127
|
+
Attributes:
|
|
128
|
+
----------
|
|
129
|
+
n_splits_ : int
|
|
130
|
+
The number of splits.
|
|
131
|
+
|
|
132
|
+
Examples:
|
|
133
|
+
--------
|
|
134
|
+
>>> import numpy as np
|
|
135
|
+
>>> from ml4t.diagnostic.splitters import PurgedWalkForwardCV
|
|
136
|
+
>>> X = np.arange(100).reshape(100, 1)
|
|
137
|
+
>>> cv = PurgedWalkForwardCV(n_splits=3, label_horizon=5, embargo_size=2)
|
|
138
|
+
>>> for train, test in cv.split(X):
|
|
139
|
+
... print(f"Train: {len(train)}, Test: {len(test)}")
|
|
140
|
+
Train: 17, Test: 25
|
|
141
|
+
Train: 40, Test: 25
|
|
142
|
+
Train: 63, Test: 25
|
|
143
|
+
"""
|
|
144
|
+
|
|
145
|
+
def __init__(
    self,
    config: PurgedWalkForwardConfig | None = None,
    *,
    n_splits: int = 5,
    test_size: int | float | str | None = None,
    train_size: int | float | str | None = None,
    gap: int = 0,
    label_horizon: int | pd.Timedelta = 0,
    embargo_size: int | pd.Timedelta | None = None,
    embargo_pct: float | None = None,
    expanding: bool = True,
    consecutive: bool = False,
    calendar: str | CalendarConfig | TradingCalendar | None = None,
    align_to_sessions: bool = False,
    session_col: str = "session_date",
    timestamp_col: str | None = None,
    isolate_groups: bool = False,
) -> None:
    """Initialize PurgedWalkForwardCV.

    This splitter uses a config-first architecture. You can either:
    1. Pass a config object: PurgedWalkForwardCV(config=my_config)
    2. Pass individual parameters: PurgedWalkForwardCV(n_splits=5, test_size=100)

    Individual parameters are converted to a ``PurgedWalkForwardConfig``
    internally (``embargo_size`` is stored as the config's ``embargo_td``).
    ``gap``, ``embargo_pct``, ``expanding``, ``consecutive`` and ``calendar``
    are not part of that config and are kept as instance attributes.

    Raises
    ------
    ValueError
        If ``config`` is given together with any individual parameter whose
        value differs from its default.
    TypeError
        If ``calendar`` is not a str, CalendarConfig, TradingCalendar, or None.

    Examples
    --------
    >>> # Approach 1: Direct parameters (convenient)
    >>> cv = PurgedWalkForwardCV(n_splits=5, test_size=100)
    >>>
    >>> # Approach 2: Config object (for serialization/reproducibility)
    >>> from ml4t.diagnostic.splitters.config import PurgedWalkForwardConfig
    >>> config = PurgedWalkForwardConfig(n_splits=5, test_size=100)
    >>> cv = PurgedWalkForwardCV(config=config)
    >>>
    >>> # Config can be serialized
    >>> config.to_json("cv_config.json")
    >>> loaded = PurgedWalkForwardConfig.from_json("cv_config.json")
    >>> cv = PurgedWalkForwardCV(config=loaded)
    """
    if config is not None:
        # The config is the single source of truth. Any individual parameter
        # moved off its default would be silently ignored, so reject it.
        # Dict insertion order fixes the order names appear in the message.
        overridden = {
            "n_splits": n_splits != 5,
            "test_size": test_size is not None,
            "train_size": train_size is not None,
            "gap": gap != 0,
            "label_horizon": label_horizon != 0,
            "embargo_size": embargo_size is not None,
            "embargo_pct": embargo_pct is not None,
            "expanding": not expanding,
            "consecutive": bool(consecutive),
            "calendar": calendar is not None,
            "align_to_sessions": bool(align_to_sessions),
            "session_col": session_col != "session_date",
            "timestamp_col": timestamp_col is not None,
            "isolate_groups": bool(isolate_groups),
        }
        non_default_params = [name for name, was_set in overridden.items() if was_set]
        if non_default_params:
            raise ValueError(
                "Cannot specify both 'config' and individual parameters. "
                f"Got config plus: {', '.join(non_default_params)}"
            )
        self.config = config
    else:
        # Note: embargo_size maps to embargo_td in config
        self.config = PurgedWalkForwardConfig(
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            label_horizon=label_horizon,
            embargo_td=embargo_size,
            align_to_sessions=align_to_sessions,
            session_col=session_col,
            timestamp_col=timestamp_col,
            isolate_groups=isolate_groups,
        )

    # Calendar is resolved here rather than stored on the config.
    # NOTE: Calendar config could be moved to PurgedWalkForwardConfig in a future version.
    if calendar is None:
        self.calendar = None
    elif isinstance(calendar, str | CalendarConfig):
        self.calendar = TradingCalendar(calendar)
    elif isinstance(calendar, TradingCalendar):
        self.calendar = calendar
    else:
        raise TypeError(
            f"calendar must be str, CalendarConfig, TradingCalendar, or None, got {type(calendar)}"
        )

    # Attributes not carried by the config object; they come straight from the
    # constructor arguments. NOTE(review): when ``config=`` is used, overrides of
    # these are rejected above, so they always hold their defaults on that path.
    self.gap = gap
    self.embargo_pct = embargo_pct
    self.expanding = expanding
    self.consecutive = consecutive
|
|
263
|
+
|
|
264
|
+
# Property accessors for config values (clean API)
|
|
265
|
+
@property
def n_splits(self) -> int:
    """Number of walk-forward folds, as held by ``self.config``."""
    cfg = self.config
    return cfg.n_splits
|
|
269
|
+
|
|
270
|
+
@property
|
|
271
|
+
def test_size(self) -> int | float | str | None:
|
|
272
|
+
"""Test set size specification."""
|
|
273
|
+
return self.config.test_size
|
|
274
|
+
|
|
275
|
+
@property
|
|
276
|
+
def train_size(self) -> int | float | str | None:
|
|
277
|
+
"""Training set size specification."""
|
|
278
|
+
return self.config.train_size
|
|
279
|
+
|
|
280
|
+
@property
|
|
281
|
+
def label_horizon(self) -> int:
|
|
282
|
+
"""Forward-looking period of labels."""
|
|
283
|
+
return self.config.label_horizon
|
|
284
|
+
|
|
285
|
+
@property
|
|
286
|
+
def embargo_size(self) -> int | None:
|
|
287
|
+
"""Embargo buffer size."""
|
|
288
|
+
return self.config.embargo_td
|
|
289
|
+
|
|
290
|
+
@property
def align_to_sessions(self) -> bool:
    """True when ``split`` should route to session-aligned fold boundaries."""
    cfg = self.config
    return cfg.align_to_sessions
|
|
294
|
+
|
|
295
|
+
@property
def session_col(self) -> str:
    """Name of the column holding session identifiers (used with session alignment)."""
    cfg = self.config
    return cfg.session_col
|
|
299
|
+
|
|
300
|
+
@property
|
|
301
|
+
def timestamp_col(self) -> str | None:
|
|
302
|
+
"""Column name containing timestamps for time-based sizes."""
|
|
303
|
+
return self.config.timestamp_col
|
|
304
|
+
|
|
305
|
+
@property
def isolate_groups(self) -> bool:
    """True when a group may not appear in both the train and test sets."""
    cfg = self.config
    return cfg.isolate_groups
|
|
309
|
+
|
|
310
|
+
def _parse_time_size(
|
|
311
|
+
self,
|
|
312
|
+
size_spec: int | float | str,
|
|
313
|
+
timestamps: pd.DatetimeIndex | None,
|
|
314
|
+
n_samples: int,
|
|
315
|
+
) -> int:
|
|
316
|
+
"""Parse size specification and convert to sample count.
|
|
317
|
+
|
|
318
|
+
Uses calendar-aware logic if calendar is configured, otherwise falls back
|
|
319
|
+
to naive time-based calculation.
|
|
320
|
+
|
|
321
|
+
Parameters
|
|
322
|
+
----------
|
|
323
|
+
size_spec : int, float, or str
|
|
324
|
+
Size specification to parse.
|
|
325
|
+
timestamps : pd.DatetimeIndex
|
|
326
|
+
Datetime index of the data.
|
|
327
|
+
n_samples : int
|
|
328
|
+
Total number of samples in dataset.
|
|
329
|
+
|
|
330
|
+
Returns
|
|
331
|
+
-------
|
|
332
|
+
int
|
|
333
|
+
Number of samples corresponding to the size specification.
|
|
334
|
+
"""
|
|
335
|
+
if isinstance(size_spec, str):
|
|
336
|
+
# Time-based specification (e.g., "4W", "30D", "3M")
|
|
337
|
+
if timestamps is None:
|
|
338
|
+
raise ValueError(
|
|
339
|
+
"Time-based size specifications require timestamps. "
|
|
340
|
+
"For pandas DataFrames: use a DatetimeIndex. "
|
|
341
|
+
"For Polars DataFrames: set timestamp_col='your_datetime_column'. "
|
|
342
|
+
"Example: PurgedWalkForwardCV(test_size='4W', timestamp_col='date')"
|
|
343
|
+
)
|
|
344
|
+
|
|
345
|
+
# Use calendar-aware parsing if calendar is configured
|
|
346
|
+
return parse_time_size_calendar_aware(
|
|
347
|
+
size_spec=size_spec,
|
|
348
|
+
timestamps=timestamps,
|
|
349
|
+
calendar=self.calendar,
|
|
350
|
+
)
|
|
351
|
+
|
|
352
|
+
elif isinstance(size_spec, float):
|
|
353
|
+
# Proportion of dataset
|
|
354
|
+
return int(n_samples * size_spec)
|
|
355
|
+
else:
|
|
356
|
+
# Integer sample count
|
|
357
|
+
return size_spec
|
|
358
|
+
|
|
359
|
+
def get_n_splits(
    self,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"] | None = None,
    y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
    groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
) -> int:
    """Return how many train/test folds this splitter produces.

    Parameters
    ----------
    X, y, groups : array-like, optional
        Not inspected. They are accepted only so the signature matches the
        scikit-learn splitter protocol.

    Returns
    -------
    int
        The configured ``n_splits``.
    """
    _unused = (X, y, groups)  # sklearn-compatible signature; values are irrelevant
    del _unused
    return self.n_splits
|
|
385
|
+
|
|
386
|
+
def split(
    self,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
    groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
) -> Generator[tuple["NDArray[np.intp]", "NDArray[np.intp]"], None, None]:
    """Yield ``(train_indices, test_indices)`` pairs for each walk-forward fold.

    Validation runs lazily, on the first iteration of the generator. Work is
    then delegated either to session-based splitting (when
    ``align_to_sessions`` is set) or to plain sample-index splitting.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Training data.
    y : array-like of shape (n_samples,), optional
        Target variable.
    groups : array-like of shape (n_samples,), optional
        Group labels for samples.

    Yields
    ------
    train : ndarray
        Training set indices for this split.
    test : ndarray
        Test set indices for this split.
    """
    sample_count = self._validate_data(X, y, groups)

    use_sessions = self.align_to_sessions
    self._validate_session_alignment(X, use_sessions, self.session_col)

    if not use_sessions:
        yield from self._split_by_samples(X, y, groups, sample_count)
        return

    # _validate_session_alignment has already confirmed X is a DataFrame here.
    frame = cast(pl.DataFrame | pd.DataFrame, X)
    yield from self._split_by_sessions(frame, y, groups, sample_count)
|
|
429
|
+
|
|
430
|
+
def _split_by_samples(
    self,
    X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    _y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None,
    groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None,
    n_samples: int,
) -> Generator[tuple["NDArray[np.intp]", "NDArray[np.intp]"], None, None]:
    """Yield walk-forward folds whose boundaries are row positions.

    ``test_size`` defaults to ``n_samples // (n_splits + 1)``. Otherwise it is
    resolved, like ``train_size``, through ``_parse_time_size``. With
    ``consecutive=True`` the test windows are laid end to end and a
    ``ValueError`` is raised if they do not fit. Otherwise they are spaced
    evenly over the rows after the first ``test_size`` rows. Each training
    window ends ``gap`` rows before its test window and is then passed through
    ``apply_purging_and_embargo`` and, when requested, group isolation.
    """
    # Timestamps (Polars or pandas) feed time-based sizes and purging bounds.
    timestamps = self._extract_timestamps(X, self.timestamp_col)

    if self.test_size is not None:
        test_size = self._parse_time_size(self.test_size, timestamps, n_samples)
    else:
        test_size = n_samples // (self.n_splits + 1)

    train_size = (
        self._parse_time_size(self.train_size, timestamps, n_samples)
        if self.train_size is not None
        else None
    )

    if not self.consecutive:
        # Spread test windows evenly across the data so the folds sample
        # different periods (e.g. market regimes).
        offset = test_size
        stride = (n_samples - test_size) // self.n_splits
    else:
        # Back-to-back test windows for a realistic sequential simulation.
        # The first window starts after the initial training window, or after
        # one test-sized chunk when no train_size was given.
        stride = test_size
        offset = test_size if train_size is None else train_size
        needed = offset + self.n_splits * test_size
        if needed > n_samples:
            raise ValueError(
                f"Insufficient data for consecutive={self.consecutive}: "
                f"need {needed:,} samples (first_test at {offset:,} "
                f"+ {self.n_splits} × {test_size:,}), but only have {n_samples:,}"
            )

    final_fold = self.n_splits - 1
    for fold in range(self.n_splits):
        test_start = offset + fold * stride
        # Without an explicit test_size the final fold absorbs the remainder.
        if fold == final_fold and self.test_size is None:
            test_end = n_samples
        else:
            test_end = min(test_start + test_size, n_samples)

        # Expanding windows, and rolling windows without a train_size, train
        # from row 0; otherwise keep only the last train_size rows before the gap.
        if self.expanding or train_size is None:
            train_start = 0
        else:
            train_start = max(0, test_start - self.gap - train_size)
        train_end = test_start - self.gap

        window_start, window_end = convert_indices_to_timestamps(
            test_start,
            test_end,
            timestamps,
        )
        kept_train = apply_purging_and_embargo(
            train_indices=np.arange(train_start, train_end),
            test_start=window_start,
            test_end=window_end,
            label_horizon=self.label_horizon,
            embargo_size=self.embargo_size,
            embargo_pct=self.embargo_pct,
            n_samples=n_samples,
            timestamps=timestamps,
        )

        test_indices = np.arange(test_start, test_end, dtype=np.intp)
        if self.isolate_groups and groups is not None:
            kept_train = isolate_groups_from_train(kept_train, test_indices, groups)

        yield kept_train.astype(np.intp), test_indices
|
|
543
|
+
|
|
544
|
+
def _split_by_sessions(
    self,
    X: pl.DataFrame | pd.DataFrame,
    _y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None,
    groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None,
    n_samples: int,
) -> Generator[tuple["NDArray[np.intp]", "NDArray[np.intp]"], None, None]:
    """Generate splits using session boundaries (session-aware).

    Test size, train size, gap and step are all measured in whole sessions,
    taken in the order returned by ``_get_unique_sessions``. Each fold's
    session ranges are then mapped back to row indices by session membership.

    Parameters
    ----------
    X : pl.DataFrame or pd.DataFrame
        Data containing ``self.session_col`` (and optionally
        ``self.timestamp_col`` for purging/embargo).
    _y : array-like or None
        Unused.
    groups : array-like or None
        Group labels, used only when ``self.isolate_groups`` is enabled.
    n_samples : int
        Number of rows in ``X``.

    Yields
    ------
    train : ndarray of np.intp
        Training row indices for this split.
    test : ndarray of np.intp
        Test row indices for this split.

    Raises
    ------
    ValueError
        If ``test_size``/``train_size`` is neither an integer (session count)
        nor a float (proportion), if the test window resolves to fewer than
        one session, or if ``consecutive=True`` and the sessions cannot hold
        every fold.
    """
    # Get unique sessions in chronological order
    unique_sessions = self._get_unique_sessions(X, self.session_col)
    n_sessions = len(unique_sessions)

    # Extract timestamps if available (for purging/embargo)
    timestamps = self._extract_timestamps(X, self.timestamp_col)

    def to_sessions(value: object, name: str) -> int:
        # Integer-like (incl. NumPy ints) -> number of sessions;
        # float -> proportion of sessions. np.timedelta64 subclasses
        # np.integer, so exclude it: durations are time-based and rejected.
        if isinstance(value, (int, np.integer)) and not isinstance(value, np.timedelta64):
            return int(value)
        if isinstance(value, (float, np.floating)):
            return int(n_sessions * value)
        raise ValueError(
            f"align_to_sessions=True does not support time-based {name}. "
            f"Use integer (number of sessions) or float (proportion). Got: {value}"
        )

    # Calculate test size in sessions
    if self.test_size is None:
        test_size_sessions = n_sessions // (self.n_splits + 1)
    else:
        test_size_sessions = to_sessions(self.test_size, "test_size")

    if test_size_sessions < 1:
        # An empty test window would silently produce useless folds.
        raise ValueError(
            f"Test window resolves to {test_size_sessions} sessions "
            f"(n_sessions={n_sessions:,}, n_splits={self.n_splits}, "
            f"test_size={self.test_size}); each test fold needs at least 1 session"
        )

    # Calculate train size in sessions if specified
    if self.train_size is not None:
        train_size_sessions = to_sessions(self.train_size, "train_size")
    else:
        train_size_sessions = None

    # Calculate split points in session space
    if self.consecutive:
        # Back-to-back test windows; the first starts after the initial
        # training window (or one test-sized chunk if train_size is unset).
        step_size_sessions = test_size_sessions
        first_test_start_session = (
            train_size_sessions if train_size_sessions is not None else test_size_sessions
        )

        total_required_sessions = first_test_start_session + self.n_splits * test_size_sessions
        if total_required_sessions > n_sessions:
            raise ValueError(
                f"Insufficient sessions for consecutive={self.consecutive}: "
                f"need {total_required_sessions:,} sessions (first_test at {first_test_start_session:,} "
                f"+ {self.n_splits} × {test_size_sessions:,}), but only have {n_sessions:,}"
            )
    else:
        # Spread test windows across the available sessions.
        available_for_splits_sessions = n_sessions - test_size_sessions
        step_size_sessions = available_for_splits_sessions // self.n_splits
        first_test_start_session = test_size_sessions

    # Loop invariants: convert the session ordering to a list once (slicing a
    # list is equivalent to slicing the Series and converting per fold), and
    # look up the session column once.
    if isinstance(unique_sessions, pl.Series):
        session_list = unique_sessions.to_list()
    else:  # pandas Series
        session_list = unique_sessions.tolist()
    session_col_values = X[self.session_col]
    is_polars = isinstance(X, pl.DataFrame)

    # Generate splits by mapping session ranges to row indices
    for i in range(self.n_splits):
        # Calculate test session range
        test_start_session = first_test_start_session + i * step_size_sessions
        test_end_session = min(test_start_session + test_size_sessions, n_sessions)

        # Without an explicit test_size the final fold absorbs the remainder.
        if i == self.n_splits - 1 and self.test_size is None:
            test_end_session = n_sessions

        # Calculate train session range
        if self.expanding or train_size_sessions is None:
            train_start_session = 0
        else:
            train_start_session = max(
                0, test_start_session - self.gap - train_size_sessions
            )

        # Clamp at 0: a negative slice end would count back from the END of
        # the session list and leak test (and later) sessions into training.
        train_end_session = max(0, test_start_session - self.gap)

        train_sessions = session_list[train_start_session:train_end_session]
        test_sessions = session_list[test_start_session:test_end_session]

        # Map sessions to row indices
        if is_polars:
            train_mask = session_col_values.is_in(train_sessions)
            test_mask = session_col_values.is_in(test_sessions)
        else:  # pandas DataFrame
            session_col_pd = cast(pd.Series, session_col_values)
            train_mask = session_col_pd.isin(train_sessions)
            test_mask = session_col_pd.isin(test_sessions)
        train_indices = np.where(train_mask.to_numpy())[0]
        test_indices = np.where(test_mask.to_numpy())[0]

        # Apply purging and embargo if configured
        if self._has_purging_or_embargo():
            # Use the actual timestamp bounds of the test rows rather than
            # positional [0]/[-1]: multi-asset data may be sorted by asset
            # rather than time.
            test_start_time, test_end_time = self._timestamp_window_from_indices(
                test_indices, timestamps
            )

            clean_train_indices = apply_purging_and_embargo(
                train_indices=train_indices,
                test_start=test_start_time,
                test_end=test_end_time,
                label_horizon=self.label_horizon,
                embargo_size=self.embargo_size,
                embargo_pct=self.embargo_pct,
                n_samples=n_samples,
                timestamps=timestamps,
            )
        else:
            clean_train_indices = train_indices

        # Apply group isolation if requested
        if self.isolate_groups and groups is not None:
            clean_train_indices = isolate_groups_from_train(
                clean_train_indices, test_indices, groups
            )

        yield clean_train_indices.astype(np.intp), test_indices.astype(np.intp)
|
|
692
|
+
|
|
693
|
+
def _has_purging_or_embargo(self) -> bool:
    """Check if purging or embargo is needed.

    ``label_horizon`` counts only when strictly positive. It may be a sample
    count (Python or NumPy int/float) or a duration (``pd.Timedelta``,
    ``datetime.timedelta`` or ``np.timedelta64``). Any non-None
    ``embargo_size`` or ``embargo_pct`` counts as an embargo request.

    Returns
    -------
    bool
        True if purging or embargo should be applied.
    """
    horizon = self.label_horizon
    if isinstance(horizon, np.timedelta64):
        # Checked first: np.timedelta64 subclasses np.integer.
        has_label_horizon = bool(horizon > np.timedelta64(0))
    elif isinstance(horizon, (int, float, np.integer, np.floating)):
        # NumPy scalars are not int/float subclasses and previously fell
        # through to "no horizon", silently disabling purging.
        has_label_horizon = bool(horizon > 0)
    elif hasattr(horizon, "total_seconds"):  # pd.Timedelta / datetime.timedelta
        has_label_horizon = bool(horizon.total_seconds() > 0)
    else:
        has_label_horizon = False

    # embargo_size may be int or Timedelta; embargo_pct is float or None
    has_embargo = self.embargo_size is not None or self.embargo_pct is not None

    return has_label_horizon or has_embargo
|
|
714
|
+
|
|
715
|
+
@staticmethod
|
|
716
|
+
def _timestamp_window_from_indices(
|
|
717
|
+
indices: "NDArray[np.intp]",
|
|
718
|
+
timestamps: pd.DatetimeIndex | None,
|
|
719
|
+
) -> tuple[int | pd.Timestamp, int | pd.Timestamp]:
|
|
720
|
+
"""Compute timestamp window from actual indices (for session-aligned purging).
|
|
721
|
+
|
|
722
|
+
This is critical for correct purging in session-aligned mode. Instead of
|
|
723
|
+
using positional indices [0] and [-1] which assume chronological ordering,
|
|
724
|
+
we compute the actual timestamp bounds from all test indices.
|
|
725
|
+
|
|
726
|
+
For multi-asset data where rows may be sorted by asset rather than time,
|
|
727
|
+
test_indices[0] may not have the minimum timestamp.
|
|
728
|
+
|
|
729
|
+
Parameters
|
|
730
|
+
----------
|
|
731
|
+
indices : ndarray
|
|
732
|
+
Row indices of test samples.
|
|
733
|
+
timestamps : pd.DatetimeIndex or None
|
|
734
|
+
Timestamps for all samples. If None, returns index bounds.
|
|
735
|
+
|
|
736
|
+
Returns
|
|
737
|
+
-------
|
|
738
|
+
start_time : int or pd.Timestamp
|
|
739
|
+
Minimum timestamp of test indices (or min index if no timestamps).
|
|
740
|
+
end_time_exclusive : int or pd.Timestamp
|
|
741
|
+
Maximum timestamp + 1 nanosecond (or max index + 1 if no timestamps).
|
|
742
|
+
"""
|
|
743
|
+
if len(indices) == 0:
|
|
744
|
+
# Empty indices - return minimal bounds
|
|
745
|
+
if timestamps is None:
|
|
746
|
+
return 0, 0
|
|
747
|
+
return timestamps[0], timestamps[0]
|
|
748
|
+
|
|
749
|
+
if timestamps is None:
|
|
750
|
+
# No timestamps - return index bounds
|
|
751
|
+
return int(indices.min()), int(indices.max()) + 1
|
|
752
|
+
|
|
753
|
+
test_timestamps = timestamps.take(indices)
|
|
754
|
+
start_time = test_timestamps.min()
|
|
755
|
+
# Add 1 nanosecond to make end exclusive (handles duplicate timestamps)
|
|
756
|
+
end_time_exclusive = test_timestamps.max() + pd.Timedelta(1, "ns")
|
|
757
|
+
return start_time, end_time_exclusive
|