ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1064 @@
+"""Combinatorial Purged Cross-Validation for backtest overfitting detection.
+
+This module implements Combinatorial Purged Cross-Validation (CPCV), which generates
+multiple backtest paths by combining different groups of time-series data. This approach
+provides a distribution of performance metrics instead of a single path, enabling robust
+assessment of strategy viability and detection of backtest overfitting.
+
+Key Concepts
+------------
+
+**Combinatorial Splits**:
+Instead of a single chronological train/test split, CPCV partitions data into N groups
+and generates all C(N,k) combinations of choosing k groups for testing. This creates
+a distribution of backtest results rather than a single path.
+
+**Purging**:
+Removes training samples that temporally overlap with test samples within the label
+horizon. Essential for preventing information leakage when labels are forward-looking
+(e.g., future returns). Without purging, the model could train on samples that contain
+information about test set labels.
+
+**Embargo**:
+Creates a buffer period after each test group where training samples are removed.
+Accounts for serial correlation in financial data and prevents training on samples
+that are too close in time to the test set. Can be specified as absolute time
+(embargo_size) or as a percentage of total samples (embargo_pct).
+
+**Session Alignment**:
+Optionally aligns group boundaries to trading session boundaries rather than arbitrary
+indices. Ensures groups represent complete trading days/sessions, which is important
+for intraday strategies.
+
+**Multi-Asset Isolation**:
+When the groups parameter is provided, CPCV applies purging per asset independently.
+This prevents cross-asset information leakage and enables proper validation of
+multi-asset strategies.
+
+Usage Example
+-------------
+Basic usage with purging and embargo::
+
+    import polars as pl
+    from ml4t.diagnostic.splitters import CombinatorialPurgedCV
+
+    # Load your time-series data
+    df = pl.read_parquet("features.parquet")
+    X = df.select(["feature1", "feature2", "feature3"])
+    y = df["target"]
+
+    # Configure CPCV with purging for 5-day forward labels
+    # and 2-day embargo to account for autocorrelation
+    cv = CombinatorialPurgedCV(
+        n_groups=8,          # Split into 8 time groups
+        n_test_groups=2,     # Use 2 groups for testing in each combination
+        label_horizon=5,     # Labels look forward 5 samples
+        embargo_size=2,      # Add 2-sample buffer after test set
+        max_combinations=20  # Limit to 20 combinations for efficiency
+    )
+
+    # Generate train/test splits
+    for fold, (train_idx, test_idx) in enumerate(cv.split(X)):
+        X_train, X_test = X[train_idx], X[test_idx]
+        y_train, y_test = y[train_idx], y[test_idx]
+
+        # Train and evaluate your model
+        model.fit(X_train, y_train)
+        score = model.score(X_test, y_test)
+        print(f"Fold {fold}: Score={score:.4f}")
+
+Multi-asset usage with per-asset purging::
+
+    # For multi-asset strategies, provide asset IDs as groups
+    assets = df["symbol"]  # e.g., ["AAPL", "MSFT", "GOOGL", ...]
+
+    cv = CombinatorialPurgedCV(
+        n_groups=6,
+        n_test_groups=2,
+        label_horizon=5,
+        embargo_size=2,
+        isolate_groups=True  # Prevent same asset in train and test
+    )
+
+    for train_idx, test_idx in cv.split(X, groups=assets):
+        # CPCV automatically applies per-asset purging
+        # Each asset's data is purged independently
+        pass
+
+Session-aligned usage for intraday strategies::
+
+    import pandas as pd
+
+    # Data with session_date column from qdata.sessions
+    df = pd.read_parquet("intraday_features.parquet")
+    # df has columns: timestamp, session_date, feature1, feature2, ...
+
+    cv = CombinatorialPurgedCV(
+        n_groups=10,
+        n_test_groups=2,
+        label_horizon=pd.Timedelta(minutes=30),  # 30-minute forward labels
+        embargo_size=pd.Timedelta(minutes=15),   # 15-minute embargo
+        align_to_sessions=True,                  # Align groups to sessions
+        session_col="session_date"               # Column with session IDs
+    )
+
+    for train_idx, test_idx in cv.split(df):
+        # Group boundaries now align to complete trading sessions
+        pass
+
+References
+----------
+.. [1] Bailey, D. H., Borwein, J., López de Prado, M., & Zhu, Q. J. (2014).
+   "The Probability of Backtest Overfitting." Journal of Computational Finance.
+
+.. [2] López de Prado, M. (2018). "Advances in Financial Machine Learning."
+   Wiley. Chapter 7: Cross-Validation in Finance.
+"""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Generator
+from typing import TYPE_CHECKING, Any, cast
+
+import numpy as np
+import pandas as pd
+import polars as pl
+
+from ml4t.diagnostic.backends.adapter import DataFrameAdapter
+from ml4t.diagnostic.splitters.base import BaseSplitter
+from ml4t.diagnostic.splitters.config import CombinatorialPurgedConfig
+from ml4t.diagnostic.splitters.cpcv import (
+    apply_multi_asset_purging,
+    apply_single_asset_purging,
+    create_contiguous_partitions,
+    create_session_partitions,
+    iter_combinations,
+    timestamp_window_from_indices,
+    validate_contiguous_partitions,
+)
+from ml4t.diagnostic.splitters.group_isolation import isolate_groups_from_train
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+class CombinatorialPurgedCV(BaseSplitter):
+    """Combinatorial Purged Cross-Validation for backtest overfitting detection.
+
+    CPCV partitions the time series into N contiguous groups and forms all combinations
+    C(N,k) of choosing k groups for testing. This generates multiple backtest paths
+    instead of a single chronological split, providing a robust assessment of strategy
+    performance and enabling detection of backtest overfitting.
+
+    How It Works
+    ------------
+
+    1. **Partitioning**: Divide time-series data into N contiguous groups of equal size
+    2. **Combination Generation**: Generate all C(N,k) combinations of choosing k groups for testing
+    3. **Purging**: For each combination, remove training samples that overlap with test labels
+    4. **Embargo**: Optionally add buffer periods after test groups to account for autocorrelation
+    5. **Multi-Asset Handling**: When groups are provided, apply purging independently per asset
+
+    Purging Mechanics
+    -----------------
+
+    **Why Purge?**
+    When labels are forward-looking (e.g., 5-day returns), training samples near the test
+    set temporally overlap with test labels. Without purging, the model trains on information
+    about test outcomes, leading to inflated performance estimates.
+
+    **How Purging Works**:
+    For each test group with range [t_start, t_end]:
+
+    1. Remove train samples preceding the test group where: ``t_train >= t_start - label_horizon``
+    2. This ensures no training sample's label period overlaps with test samples
+
+    **Example**::
+
+        Test group: samples 100-119 (20 samples)
+        Label horizon: 5 samples
+        Purging removes: training samples 95-99
+        Reason: Sample 95's label looks forward to sample 100 (first test sample)
+
+    Embargo Mechanics
+    -----------------
+
+    **Why Embargo?**
+    Financial data exhibits serial correlation - adjacent samples are not independent.
+    Even with purging, training on samples immediately after the test set can leak
+    information through autocorrelation.
+
+    **How Embargo Works**:
+    After purging, additionally remove a buffer of samples immediately after each test group:
+
+    - **embargo_size**: Absolute number of samples (e.g., 10 samples)
+    - **embargo_pct**: Percentage of total samples (e.g., 0.01 = 1%)
+
+    **Example**::
+
+        Test group: samples 100-119
+        Embargo: 5 samples
+        Additional removal: training samples 120-124
+        Result: Creates 5-sample buffer after test group
+
+    Multi-Asset Purging
+    -------------------
+
+    When the ``groups`` parameter is provided (e.g., asset symbols), CPCV applies purging
+    independently for each asset. This prevents cross-asset leakage:
+
+    **Process**:
+    1. For each asset, find its training and test samples
+    2. Apply purging/embargo only to that asset's data
+    3. Combine results across all assets
+
+    **Why Important?**
+    Without per-asset purging, information could leak between assets that trade at
+    different times (e.g., European markets vs US markets).
+
+    Based on Bailey et al. (2014) "The Probability of Backtest Overfitting" and
+    López de Prado (2018) "Advances in Financial Machine Learning".
+
+    Parameters
+    ----------
+    n_groups : int, default=8
+        Number of contiguous groups to partition the time series into.
+
+    n_test_groups : int, default=2
+        Number of groups to use for testing in each combination.
+
+    label_horizon : int or pd.Timedelta, default=0
+        Forward-looking period of labels for purging calculation.
+
+    embargo_size : int or pd.Timedelta, optional
+        Size of embargo period after each test group.
+
+    embargo_pct : float, optional
+        Embargo size as percentage of total samples.
+
+    max_combinations : int, optional
+        Maximum number of combinations to generate. If None, generates all C(N,k).
+        Use this to limit computational cost for large N.
+
+    random_state : int, optional
+        Random seed for combination sampling when max_combinations is set.
+
+    align_to_sessions : bool, default=False
+        If True, align group boundaries to trading session boundaries.
+        Requires X to have a session column (specified by the session_col parameter).
+
+        Trading sessions should be assigned using the qdata library before cross-validation:
+        - Use DataManager with exchange/calendar parameters, or
+        - Use SessionAssigner.from_exchange('CME') directly
+
+    session_col : str, default='session_date'
+        Name of the column containing session identifiers.
+        Only used if align_to_sessions=True.
+        This column should be added by qdata.sessions.SessionAssigner.
+
+    isolate_groups : bool, default=True
+        If True, prevent the same group (asset/symbol) from appearing in both
+        train and test sets. This is enabled by default for CPCV as it's designed
+        for multi-asset validation.
+
+        Requires passing the `groups` parameter to the split() method with asset IDs.
+
+        Note: CPCV already applies per-asset purging when groups are provided.
+        This parameter provides an additional group-isolation guarantee.
+
+    Attributes
+    ----------
+    n_groups : int
+        The number of groups.
+
+    n_test_groups : int
+        The number of test groups.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from ml4t.diagnostic.splitters import CombinatorialPurgedCV
+    >>> X = np.arange(200).reshape(200, 1)
+    >>> cv = CombinatorialPurgedCV(n_groups=6, n_test_groups=2, label_horizon=5)
+    >>> combinations = list(cv.split(X))
+    >>> print(f"Generated {len(combinations)} combinations")
+    Generated 15 combinations
+
+    >>> # Each combination yields disjoint train/test indices
+    >>> train, test = combinations[0]
+    >>> bool(set(train) & set(test))
+    False
+
+    Notes
+    -----
+    The total number of combinations is C(n_groups, n_test_groups). For large values,
+    this can become computationally expensive:
+    - C(8,2) = 28 combinations
+    - C(10,3) = 120 combinations
+    - C(12,4) = 495 combinations
+
+    Use max_combinations to limit computational cost for large datasets.
+    """
+
+    def __init__(
+        self,
+        config: CombinatorialPurgedConfig | None = None,
+        *,
+        n_groups: int = 8,
+        n_test_groups: int = 2,
+        label_horizon: int | pd.Timedelta = 0,
+        embargo_size: int | pd.Timedelta | None = None,
+        embargo_pct: float | None = None,
+        max_combinations: int | None = None,
+        random_state: int | None = None,
+        align_to_sessions: bool = False,
+        session_col: str = "session_date",
+        timestamp_col: str | None = None,
+        isolate_groups: bool = True,
+    ) -> None:
+        """Initialize CombinatorialPurgedCV.
+
+        This splitter uses a config-first architecture. You can either:
+        1. Pass a config object: CombinatorialPurgedCV(config=my_config)
+        2. Pass individual parameters: CombinatorialPurgedCV(n_groups=8, n_test_groups=2)
+
+        Parameters are automatically converted to a config object internally,
+        ensuring a single source of truth for all validation and logic.
+
+        Examples
+        --------
+        >>> # Approach 1: Direct parameters (convenient)
+        >>> cv = CombinatorialPurgedCV(n_groups=10, n_test_groups=3)
+        >>>
+        >>> # Approach 2: Config object (for serialization/reproducibility)
+        >>> from ml4t.diagnostic.splitters.config import CombinatorialPurgedConfig
+        >>> config = CombinatorialPurgedConfig(n_groups=10, n_test_groups=3)
+        >>> cv = CombinatorialPurgedCV(config=config)
+        >>>
+        >>> # Config can be serialized
+        >>> config.to_json("cpcv_config.json")
+        >>> loaded = CombinatorialPurgedConfig.from_json("cpcv_config.json")
+        >>> cv = CombinatorialPurgedCV(config=loaded)
+        """
+        # Config-first: either use the provided config or create one from params
+        if config is not None:
+            # Verify no conflicting parameters when config is provided
+            self._validate_no_param_conflicts(
+                n_groups,
+                n_test_groups,
+                label_horizon,
+                embargo_size,
+                embargo_pct,
+                max_combinations,
+                random_state,
+                align_to_sessions,
+                session_col,
+                timestamp_col,
+                isolate_groups,
+            )
+            self.config = config
+        else:
+            # Create config from individual parameters
+            # Note: embargo validation (mutual exclusivity) handled by config
+            self.config = self._create_config_from_params(
+                n_groups,
+                n_test_groups,
+                label_horizon,
+                embargo_size,
+                embargo_pct,
+                max_combinations,
+                random_state,
+                align_to_sessions,
+                session_col,
+                timestamp_col,
+                isolate_groups,
+            )
+
+        # Use the parameter if provided, otherwise fall back to the config value.
+        # This allows random_state to be passed either via config or direct parameter.
+        self.random_state = random_state if random_state is not None else self.config.random_state
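# --- Editor's sketch (not part of the packaged file): the config-first round trip
# promised by the __init__ docstring above. Uses only calls shown in that docstring. ---
from ml4t.diagnostic.splitters import CombinatorialPurgedCV
from ml4t.diagnostic.splitters.config import CombinatorialPurgedConfig

config = CombinatorialPurgedConfig(n_groups=10, n_test_groups=3)
config.to_json("cpcv_config.json")                      # serialize for reproducibility
loaded = CombinatorialPurgedConfig.from_json("cpcv_config.json")
cv = CombinatorialPurgedCV(config=loaded)               # no other kwargs: no conflict
assert cv.n_groups == 10 and cv.n_test_groups == 3      # property accessors read config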
+
+    def _validate_no_param_conflicts(
+        self,
+        n_groups: int,
+        n_test_groups: int,
+        label_horizon: int | pd.Timedelta,
+        embargo_size: int | pd.Timedelta | None,
+        embargo_pct: float | None,
+        max_combinations: int | None,
+        random_state: int | None,
+        align_to_sessions: bool,
+        session_col: str,
+        timestamp_col: str | None,
+        isolate_groups: bool,
+    ) -> None:
+        """Validate no conflicting parameters when config is provided."""
+
+        def is_semantically_default(value: Any, default: Any) -> bool:
+            """Check if value is semantically equal to default.
+
+            Handles heterogeneous types:
+            - pd.Timedelta(0) is semantically equal to 0
+            - np.int64(0) is semantically equal to 0
+            - None equals None
+            """
+            if value is None and default is None:
+                return True
+            if value is None or default is None:
+                return False
+            # Handle Timedelta vs int comparison for label_horizon/embargo_size
+            if isinstance(value, pd.Timedelta):
+                if isinstance(default, int) and default == 0:
+                    return value == pd.Timedelta(0)
+                return value == default
+            if isinstance(default, pd.Timedelta):
+                if isinstance(value, int) and value == 0:
+                    return default == pd.Timedelta(0)
+                return value == default
+            # Handle numpy int types vs Python int
+            try:
+                return bool(value == default)
+            except (TypeError, ValueError):
+                return False
+
+        # Check for non-default parameter values.
+        # Note: random_state is NOT in this list because it's now in config.
+        # Users can pass random_state as a parameter to override config.random_state.
+        param_checks = [
+            ("n_groups", n_groups, 8),
+            ("n_test_groups", n_test_groups, 2),
+            ("label_horizon", label_horizon, 0),
+            ("embargo_size", embargo_size, None),
+            ("embargo_pct", embargo_pct, None),
+            ("max_combinations", max_combinations, None),
+            ("align_to_sessions", align_to_sessions, False),
+            ("session_col", session_col, "session_date"),
+            ("timestamp_col", timestamp_col, None),
+            ("isolate_groups", isolate_groups, True),
+        ]
+
+        non_default_params = [
+            name
+            for name, value, default in param_checks
+            if not is_semantically_default(value, default)
+        ]
+
+        if non_default_params:
+            raise ValueError(
+                f"Cannot specify both 'config' and individual parameters. "
+                f"Got config plus: {', '.join(non_default_params)}"
+            )
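# --- Editor's sketch (not part of the packaged file): the heterogeneous comparison
# that motivates is_semantically_default above. The printed results reflect pandas 2.x
# semantics and are an assumption of this note, not something this diff asserts. ---
import pandas as pd

# Plain equality does not treat Timedelta(0) and the integer default 0 as the same
# value, so a user passing label_horizon=pd.Timedelta(0) alongside a config would
# otherwise be flagged as a conflicting non-default parameter.
print(pd.Timedelta(0) == 0)                 # False under pandas 2.x (assumed)
print(pd.Timedelta(0) == pd.Timedelta(0))   # True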
+
+    def _create_config_from_params(
+        self,
+        n_groups: int,
+        n_test_groups: int,
+        label_horizon: int | pd.Timedelta,
+        embargo_size: int | pd.Timedelta | None,
+        embargo_pct: float | None,
+        max_combinations: int | None,
+        random_state: int | None,
+        align_to_sessions: bool,
+        session_col: str,
+        timestamp_col: str | None,
+        isolate_groups: bool,
+    ) -> CombinatorialPurgedConfig:
+        """Create config object from individual parameters."""
+        return CombinatorialPurgedConfig(
+            n_groups=n_groups,
+            n_test_groups=n_test_groups,
+            label_horizon=label_horizon,
+            embargo_td=embargo_size,
+            embargo_pct=embargo_pct,
+            max_combinations=max_combinations,
+            random_state=random_state,
+            align_to_sessions=align_to_sessions,
+            session_col=session_col,
+            timestamp_col=timestamp_col,
+            isolate_groups=isolate_groups,
+        )
+
+    # Property accessors for config values (clean API)
+    @property
+    def n_groups(self) -> int:
+        """Number of groups to partition the timeline into."""
+        return self.config.n_groups
+
+    @property
+    def n_test_groups(self) -> int:
+        """Number of groups per test set."""
+        return self.config.n_test_groups
+
+    @property
+    def label_horizon(self) -> int | pd.Timedelta:
+        """Forward-looking period of labels (int samples or Timedelta)."""
+        return self.config.label_horizon
+
+    @property
+    def embargo_size(self) -> int | pd.Timedelta | None:
+        """Embargo buffer size (int samples or Timedelta)."""
+        return self.config.embargo_td
+
+    @property
+    def embargo_pct(self) -> float | None:
+        """Embargo size as percentage of total samples."""
+        return self.config.embargo_pct
+
+    @property
+    def max_combinations(self) -> int | None:
+        """Maximum number of folds to generate."""
+        return self.config.max_combinations
+
+    @property
+    def align_to_sessions(self) -> bool:
+        """Whether to align group boundaries to sessions."""
+        return self.config.align_to_sessions
+
+    @property
+    def session_col(self) -> str:
+        """Column name containing session identifiers."""
+        return self.config.session_col
+
+    @property
+    def timestamp_col(self) -> str | None:
+        """Column name containing timestamps for time-based operations."""
+        return self.config.timestamp_col
+
+    @property
+    def isolate_groups(self) -> bool:
+        """Whether to prevent group overlap between train/test."""
+        return self.config.isolate_groups
+
+    def get_n_splits(
+        self,
+        X: pl.DataFrame | pd.DataFrame | NDArray[Any] | None = None,
+        y: pl.Series | pd.Series | NDArray[Any] | None = None,
+        groups: pl.Series | pd.Series | NDArray[Any] | None = None,
+    ) -> int:
+        """Get number of splits (combinations).
+
+        Parameters
+        ----------
+        X : array-like, optional
+            Always ignored, exists for compatibility.
+
+        y : array-like, optional
+            Always ignored, exists for compatibility.
+
+        groups : array-like, optional
+            Always ignored, exists for compatibility.
+
+        Returns
+        -------
+        n_splits : int
+            Number of combinations that will be generated.
+        """
+        del X, y, groups  # Unused, for sklearn compatibility
+        total_combinations = math.comb(self.n_groups, self.n_test_groups)
+
+        if self.max_combinations is None:
+            return total_combinations
+        return min(self.max_combinations, total_combinations)
+
+    def split(
+        self,
+        X: pl.DataFrame | pd.DataFrame | NDArray[Any],
+        y: pl.Series | pd.Series | NDArray[Any] | None = None,
+        groups: pl.Series | pd.Series | NDArray[Any] | None = None,
+    ) -> Generator[tuple[NDArray[np.intp], NDArray[np.intp]], None, None]:
+        """Generate train/test indices for combinatorial splits with purging and embargo.
+
+        This method generates all combinations C(N,k) of train/test splits, applying
+        purging and embargo to prevent information leakage. Each yielded split represents
+        an independent backtest path.
+
+        Parameters
+        ----------
+        X : DataFrame or ndarray of shape (n_samples, n_features)
+            Training data. Must have a datetime index if using Timedelta-based
+            label_horizon or embargo_size.
+
+        y : Series or ndarray of shape (n_samples,), optional
+            Target variable. Not used in splitting logic, but accepted for
+            API compatibility with scikit-learn.
+
+        groups : Series or ndarray of shape (n_samples,), optional
+            Group labels for samples (e.g., asset symbols for multi-asset strategies).
+
+            When provided:
+            - Purging is applied independently per group (asset)
+            - Prevents information leakage across groups
+            - Essential for multi-asset portfolio validation
+
+            Example: ``groups = df["symbol"]``  # ["AAPL", "MSFT", "GOOGL", ...]
+
+        Yields
+        ------
+        train : ndarray of shape (n_train_samples,)
+            Indices of training samples for this combination.
+            Purging and embargo have been applied to remove:
+            - Samples overlapping with test labels (purging)
+            - Samples in the embargo buffer after test groups (embargo)
+
+        test : ndarray of shape (n_test_samples,)
+            Indices of test samples for this combination.
+            Consists of samples from the k selected test groups.
+
+        Raises
+        ------
+        ValueError
+            If X has an incompatible shape or is missing required columns
+            (e.g., session_col when align_to_sessions=True).
+
+        TypeError
+            If the X index is not datetime when using Timedelta parameters.
+
+        Notes
+        -----
+        **Number of Combinations**:
+        Generates C(n_groups, n_test_groups) combinations. For example:
+        - C(8,2) = 28 combinations
+        - C(10,3) = 120 combinations
+        - C(12,4) = 495 combinations
+
+        Use the ``max_combinations`` parameter to limit the number of splits generated.
+
+        **Purging Logic**:
+        For each test group:
+        1. Identify the test sample range [t_start, t_end]
+        2. Remove training samples before the test group where: t_train >= t_start - label_horizon
+        3. This prevents training on samples whose labels overlap with the test period
+
+        **Embargo Logic**:
+        After purging, additionally remove training samples:
+        - In range [t_end + 1, t_end + embargo_size]
+        - This accounts for serial correlation in financial time series
+
+        **Multi-Asset Handling**:
+        When ``groups`` is provided:
+        1. For each asset, find its training and test indices
+        2. Apply purging/embargo independently to that asset's data
+        3. Combine purged results across all assets
+        4. This prevents cross-asset information leakage
+
+        **Session Alignment**:
+        When ``align_to_sessions=True``:
+        - Group boundaries align to trading session boundaries
+        - Ensures each group contains complete trading days/sessions
+        - Requires X to have the column specified by the ``session_col`` parameter
+
+        Examples
+        --------
+        Basic usage with purging::
+
+            >>> import polars as pl
+            >>> from ml4t.diagnostic.splitters import CombinatorialPurgedCV
+            >>>
+            >>> # Create sample data
+            >>> n = 1000
+            >>> X = pl.DataFrame({"feature1": range(n), "feature2": range(n, 2*n)})
+            >>> y = pl.Series(range(n))
+            >>>
+            >>> # Configure CPCV
+            >>> cv = CombinatorialPurgedCV(
+            ...     n_groups=8,
+            ...     n_test_groups=2,
+            ...     label_horizon=5,
+            ...     embargo_size=2
+            ... )
+            >>>
+            >>> # Generate splits; train sizes vary slightly with purging/embargo
+            >>> splits = list(cv.split(X))
+            >>> len(splits)
+            28
+
+        Multi-asset usage::
+
+            >>> # Multi-asset data with symbol column
+            >>> symbols = pl.Series(["AAPL"] * 250 + ["MSFT"] * 250 +
+            ...                     ["GOOGL"] * 250 + ["AMZN"] * 250)
+            >>>
+            >>> cv = CombinatorialPurgedCV(
+            ...     n_groups=6,
+            ...     n_test_groups=2,
+            ...     label_horizon=5,
+            ...     embargo_size=2,
+            ...     isolate_groups=True
+            ... )
+            >>>
+            >>> for train_idx, test_idx in cv.split(X, groups=symbols):
+            ...     # Purging applied independently per asset
+            ...     train_symbols = symbols[train_idx].unique()
+            ...     test_symbols = symbols[test_idx].unique()
+
+        Session-aligned usage::
+
+            >>> import pandas as pd
+            >>>
+            >>> # Intraday data with session dates (15-minute bars span ~10 sessions)
+            >>> df = pd.DataFrame({
+            ...     "timestamp": pd.date_range("2024-01-01", periods=1000, freq="15min"),
+            ...     "session_date": pd.date_range("2024-01-01", periods=1000, freq="15min").date,
+            ...     "feature1": range(1000)
+            ... })
+            >>>
+            >>> cv = CombinatorialPurgedCV(
+            ...     n_groups=10,
+            ...     n_test_groups=2,
+            ...     label_horizon=pd.Timedelta(minutes=30),
+            ...     embargo_size=pd.Timedelta(minutes=15),
+            ...     align_to_sessions=True,
+            ...     session_col="session_date"
+            ... )
+            >>>
+            >>> for train_idx, test_idx in cv.split(df):
+            ...     # Group boundaries aligned to session boundaries
+            ...     pass
+
+        See Also
+        --------
+        CombinatorialPurgedConfig : Configuration object for CPCV parameters
+        apply_purging_and_embargo : Low-level purging/embargo function
+        BaseSplitter : Base class for all splitters
+        """
+        # Validate inputs (no numpy conversion - performance optimization)
+        n_samples = self._validate_inputs(X, y, groups)
+
+        # Validate session alignment if enabled
+        self._validate_session_alignment(X, self.align_to_sessions, self.session_col)
+
+        # Extract timestamps if available (supports both Polars and pandas)
+        timestamps = self._extract_timestamps(X, self.timestamp_col)
+
+        # Create group indices or boundaries.
+        # For session-aligned mode, we need exact indices (not boundaries) to handle
+        # non-contiguous/interleaved data correctly.
+        if self.align_to_sessions:
+            # align_to_sessions requires X to be a DataFrame (validation enforces this).
+            # Use the method that returns exact indices per group.
+            group_indices_list = self._create_session_group_indices(
+                cast(pl.DataFrame | pd.DataFrame, X)
+            )
+            use_exact_indices = True
+            # Also create boundaries for backward compatibility with the purging logic
+            group_boundaries = [
+                (int(indices[0]), int(indices[-1]) + 1) if len(indices) > 0 else (0, 0)
+                for indices in group_indices_list
+            ]
+        else:
+            group_boundaries = self._create_group_boundaries(n_samples)
+            group_indices_list = None
+            use_exact_indices = False
+
+        # Generate combinations with memory-efficient sampling when max_combinations is set.
+        # Uses reservoir sampling when needed to avoid materializing all C(n,k) combinations.
+        combinations = iter_combinations(
+            self.n_groups,
+            self.n_test_groups,
+            self.max_combinations,
+            self.random_state,
+        )
+
+        # Generate splits for each combination
+        for test_group_indices in combinations:
+            # Create the test set from the selected groups
+            if use_exact_indices and group_indices_list is not None:
+                # Use exact indices (correct for non-contiguous/interleaved data)
+                test_arrays = [group_indices_list[g] for g in test_group_indices]
+                test_indices_array = (
+                    np.concatenate(test_arrays) if test_arrays else np.array([], dtype=np.intp)
+                )
+            else:
+                # Use boundaries with range (only correct for contiguous data)
+                test_indices: list[int] = []
+                for group_idx in test_group_indices:
+                    start_idx, end_idx = group_boundaries[group_idx]
+                    test_indices.extend(range(start_idx, end_idx))
+                test_indices_array = np.array(test_indices, dtype=np.intp)
+
+            # Create the initial training set from the remaining groups
+            train_group_indices_list = [
+                i for i in range(self.n_groups) if i not in test_group_indices
+            ]
+            if use_exact_indices and group_indices_list is not None:
+                # Use exact indices
+                train_arrays = [group_indices_list[g] for g in train_group_indices_list]
+                train_indices_array = (
+                    np.concatenate(train_arrays) if train_arrays else np.array([], dtype=np.intp)
+                )
+            else:
+                # Use boundaries with range
+                train_indices: list[int] = []
+                for group_idx in train_group_indices_list:
+                    start_idx, end_idx = group_boundaries[group_idx]
+                    train_indices.extend(range(start_idx, end_idx))
+                train_indices_array = np.array(train_indices, dtype=np.intp)
+
+            # Apply purging and embargo between test groups and training data
+            clean_train_indices = self._apply_group_purging_and_embargo(
+                train_indices_array,
+                test_group_indices,
+                group_boundaries,
+                n_samples,
+                timestamps,
+                groups,  # Pass groups for multi-asset awareness
+                group_indices_list,  # Pass exact indices for session-aligned purging
+            )
+
+            # Apply group isolation if requested
+            if self.isolate_groups and groups is not None:
+                clean_train_indices = isolate_groups_from_train(
+                    clean_train_indices, test_indices_array, groups
+                )
+
+            # CPCV invariant: the train set must not be empty after purging
+            if len(clean_train_indices) == 0:
+                raise ValueError(
+                    f"CPCV invariant violated: train set is empty after purging/embargo. "
+                    f"Test groups: {test_group_indices}. "
+                    f"Consider reducing label_horizon ({self.label_horizon}) or "
+                    f"embargo_size ({self.embargo_size}) or embargo_pct ({self.embargo_pct})."
+                )
+
+            # CPCV invariant: train and test sets must be disjoint
+            overlap = np.intersect1d(clean_train_indices, test_indices_array)
+            if len(overlap) > 0:
+                raise ValueError(
+                    f"CPCV invariant violated: train and test sets have {len(overlap)} "
+                    f"overlapping indices. First few: {overlap[:5].tolist()}"
+                )
+
+            # Return sorted indices for deterministic behavior
+            yield np.sort(clean_train_indices), np.sort(test_indices_array)
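# --- Editor's sketch (not part of the packaged file): re-checking, from the caller's
# side, the two invariants split() enforces above. Uses only the public API that the
# docstrings in this diff demonstrate; assumes the wheel is installed. ---
import numpy as np
from ml4t.diagnostic.splitters import CombinatorialPurgedCV

X = np.arange(240).reshape(240, 1)
cv = CombinatorialPurgedCV(n_groups=6, n_test_groups=2, label_horizon=5, embargo_size=2)

assert cv.get_n_splits() == 15                            # C(6,2) folds
for train_idx, test_idx in cv.split(X):
    assert len(train_idx) > 0                             # non-empty train invariant
    assert len(np.intersect1d(train_idx, test_idx)) == 0  # disjointness invariant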
+
+    def _create_group_boundaries(self, n_samples: int) -> list[tuple[int, int]]:
+        """Create boundaries for contiguous groups.
+
+        Delegates to cpcv.partitioning.create_contiguous_partitions.
+
+        Parameters
+        ----------
+        n_samples : int
+            Total number of samples.
+
+        Returns
+        -------
+        boundaries : list of tuple
+            List of (start_idx, end_idx) for each group.
+
+        Raises
+        ------
+        ValueError
+            If boundaries don't satisfy CPCV invariants.
+        """
+        return create_contiguous_partitions(n_samples, self.n_groups)
+
+    def _validate_group_boundaries(self, boundaries: list[tuple[int, int]], n_samples: int) -> None:
+        """Validate CPCV group boundary invariants.
+
+        Delegates to cpcv.partitioning.validate_contiguous_partitions.
+        """
+        validate_contiguous_partitions(boundaries, n_samples)
+
+    def _create_session_group_indices(
+        self,
+        X: pl.DataFrame | pd.DataFrame,
+    ) -> list[NDArray[np.intp]]:
+        """Create exact index arrays per group, aligned to session boundaries.
+
+        Delegates to cpcv.partitioning.create_session_partitions.
+
+        Unlike _create_group_boundaries, which returns (start, end) ranges suitable
+        for contiguous data, this method returns EXACT index arrays for each group.
+        This is critical for correct behavior with non-contiguous or interleaved data.
+
+        Parameters
+        ----------
+        X : DataFrame
+            Data with session column.
+
+        Returns
+        -------
+        group_indices : list of np.ndarray
+            List of numpy arrays containing exact row indices for each group.
+        """
+        return create_session_partitions(
+            X, self.session_col, self.n_groups, self._session_to_indices
+        )
+
+    @staticmethod
+    def _timestamp_window_from_indices(
+        indices: NDArray[np.intp],
+        timestamps: pd.DatetimeIndex,
+    ) -> tuple[pd.Timestamp, pd.Timestamp] | None:
+        """Compute timestamp window from actual indices (for session-aligned purging).
+
+        Delegates to cpcv.windows.timestamp_window_from_indices.
+
+        Parameters
+        ----------
+        indices : ndarray
+            Row indices of test samples.
+        timestamps : pd.DatetimeIndex
+            Timestamps for all samples.
+
+        Returns
+        -------
+        tuple or None
+            (start_time, end_time_exclusive) if indices non-empty, None if empty.
+        """
+        window = timestamp_window_from_indices(indices, timestamps)
+        if window is None:
+            return None
+        return window.start, window.end_exclusive
+
+    def _apply_group_purging_and_embargo(
+        self,
+        train_indices: NDArray[np.intp],
+        test_group_indices: tuple[int, ...],
+        group_boundaries: list[tuple[int, int]],
+        n_samples: int,
+        timestamps: pd.DatetimeIndex | None,
+        groups: pl.Series | pd.Series | NDArray[Any] | None = None,
+        group_indices_list: list[NDArray[np.intp]] | None = None,
+    ) -> NDArray[np.intp]:
+        """Apply purging and embargo between test groups and training data.
+
+        This method handles both single-asset and multi-asset scenarios.
+        For multi-asset data, purging is applied per asset to prevent
+        cross-asset look-ahead bias.
+
+        Parameters
+        ----------
+        train_indices : ndarray
+            Initial training indices.
+
+        test_group_indices : tuple of int
+            Indices of groups used for testing.
+
+        group_boundaries : list of tuple
+            Boundaries of all groups (used for non-session-aligned mode).
+
+        n_samples : int
+            Total number of samples.
+
+        timestamps : pd.DatetimeIndex, optional
+            Timestamps for the data.
+
+        groups : array-like, optional
+            Group labels for multi-asset data (e.g., asset IDs).
+            If None, applies single-asset purging logic.
+
+        group_indices_list : list of ndarray, optional
+            Exact indices per group (for session-aligned mode). When provided
+            along with timestamps, purging uses actual timestamp bounds instead
+            of (min_idx, max_idx) boundaries.
+
+        Returns
+        -------
+        clean_indices : ndarray
+            Training indices after purging and embargo.
+        """
+        if groups is None:
+            # Single-asset case: apply global purging
+            return self._apply_single_asset_purging(
+                train_indices,
+                test_group_indices,
+                group_boundaries,
+                n_samples,
+                timestamps,
+                group_indices_list,
+            )
+        # Multi-asset case: apply per-asset purging
+        return self._apply_multi_asset_purging(
+            train_indices,
+            test_group_indices,
+            group_boundaries,
+            n_samples,
+            timestamps,
+            groups,
+            group_indices_list,
+        )
+
+    def _apply_single_asset_purging(
+        self,
+        train_indices: NDArray[np.intp],
+        test_group_indices: tuple[int, ...],
+        group_boundaries: list[tuple[int, int]],
+        n_samples: int,
+        timestamps: pd.DatetimeIndex | None,
+        group_indices_list: list[NDArray[np.intp]] | None = None,
+    ) -> NDArray[np.intp]:
+        """Apply purging for single-asset data.
+
+        Delegates to cpcv.purge_engine.apply_single_asset_purging.
+        """
+        return apply_single_asset_purging(
+            train_indices=train_indices,
+            test_group_indices=test_group_indices,
+            group_boundaries=group_boundaries,
+            n_samples=n_samples,
+            timestamps=timestamps,
+            label_horizon=self.label_horizon,
+            embargo_size=self.embargo_size,
+            embargo_pct=self.embargo_pct,
+            group_indices_list=group_indices_list,
+        )
+
+    def _apply_multi_asset_purging(
+        self,
+        train_indices: NDArray[np.intp],
+        test_group_indices: tuple[int, ...],
+        group_boundaries: list[tuple[int, int]],
+        n_samples: int,
+        timestamps: pd.DatetimeIndex | None,
+        groups: pl.Series | pd.Series | NDArray[Any],
+        group_indices_list: list[NDArray[np.intp]] | None = None,
+    ) -> NDArray[np.intp]:
+        """Apply purging for multi-asset data with per-asset isolation.
+
+        Delegates to cpcv.purge_engine.apply_multi_asset_purging.
+        """
+        # Convert groups to a numpy array for consistent indexing
+        groups_array = DataFrameAdapter.to_numpy(groups).flatten()
+
+        return apply_multi_asset_purging(
+            train_indices=train_indices,
+            test_group_indices=test_group_indices,
+            group_boundaries=group_boundaries,
+            n_samples=n_samples,
+            timestamps=timestamps,
+            groups_array=groups_array,
+            label_horizon=self.label_horizon,
+            embargo_size=self.embargo_size,
+            embargo_pct=self.embargo_pct,
+            group_indices_list=group_indices_list,
+        )
+
+    def _validate_inputs(
+        self,
+        X: pl.DataFrame | pd.DataFrame | NDArray[Any],
+        y: pl.Series | pd.Series | NDArray[Any] | None = None,
+        groups: pl.Series | pd.Series | NDArray[Any] | None = None,
+    ) -> int:
+        """Validate input shapes and return the number of samples.
+
+        Unlike the previous implementation, this does NOT convert to numpy
+        for performance - it just validates shapes directly.
+        """
+        # Use base class validation (handles all input types efficiently)
+        n_samples = self._validate_data(X, y, groups)
+
+        # Validate minimum samples per group
+        min_samples_per_group = n_samples // self.n_groups
+        if min_samples_per_group < 1:
+            raise ValueError(
+                f"Not enough samples ({n_samples}) for {self.n_groups} groups. "
+                f"Need at least {self.n_groups} samples.",
+            )
+
+        return n_samples