ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""Group partitioning strategies for CPCV.
|
|
2
|
+
|
|
3
|
+
This module handles partitioning the timeline into groups:
|
|
4
|
+
- Contiguous partitioning (equal-sized time slices)
|
|
5
|
+
- Session-aligned partitioning (respects trading session boundaries)
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from numpy.typing import NDArray
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import polars as pl
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_contiguous_partitions(
    n_samples: int,
    n_groups: int,
) -> list[tuple[int, int]]:
    """Create boundaries for contiguous groups.

    Partitions ``n_samples`` into ``n_groups`` approximately equal-sized groups.
    When ``n_samples`` is not evenly divisible, each of the first
    ``n_samples % n_groups`` groups gets one extra sample.

    Parameters
    ----------
    n_samples : int
        Total number of samples.
    n_groups : int
        Number of groups to create. Must be at least 1.

    Returns
    -------
    boundaries : list of tuple
        List of (start_idx, end_idx) for each group.
        end_idx is exclusive (standard Python convention).

    Raises
    ------
    ValueError
        If ``n_groups`` is less than 1, or if the boundaries don't satisfy
        CPCV invariants (e.g. fewer samples than groups, which yields empty
        groups).

    Examples
    --------
    >>> create_contiguous_partitions(100, 5)
    [(0, 20), (20, 40), (40, 60), (60, 80), (80, 100)]

    >>> create_contiguous_partitions(103, 5)
    [(0, 21), (21, 42), (42, 63), (63, 83), (83, 103)]
    """
    # Guard explicitly: n_groups == 0 would otherwise surface as a
    # ZeroDivisionError from the integer division below, and a negative
    # n_groups would only fail later with a less specific message.
    if n_groups < 1:
        raise ValueError(f"n_groups must be at least 1, got {n_groups}")

    base_size = n_samples // n_groups
    remainder = n_samples % n_groups

    boundaries = []
    current_start = 0

    for i in range(n_groups):
        # Add extra sample to first 'remainder' groups
        group_size = base_size + (1 if i < remainder else 0)
        group_end = current_start + group_size

        boundaries.append((current_start, group_end))
        current_start = group_end

    # Validate invariants (full coverage, contiguity, non-empty groups)
    validate_contiguous_partitions(boundaries, n_samples)

    return boundaries
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def validate_contiguous_partitions(
    boundaries: list[tuple[int, int]],
    n_samples: int,
) -> None:
    """Check that group boundaries form a valid CPCV partition of ``[0, n_samples)``.

    Checks run in this order, and the first failure is reported:

    1. At least one boundary exists.
    2. The first group starts at 0 and the last group ends at ``n_samples``.
    3. Each group starts exactly where the previous one ended (no gaps, no
       overlaps).
    4. Every group is non-empty (``end > start``).

    Parameters
    ----------
    boundaries : list of tuple
        List of (start_idx, end_idx) for each group.
    n_samples : int
        Total number of samples.

    Raises
    ------
    ValueError
        If any invariant is violated. Returns ``None`` otherwise.
    """
    if not boundaries:
        raise ValueError("CPCV invariant violated: no group boundaries created")

    first_start = boundaries[0][0]
    if first_start != 0:
        raise ValueError(
            f"CPCV invariant violated: first group must start at 0, got {first_start}"
        )

    last_end = boundaries[-1][1]
    if last_end != n_samples:
        raise ValueError(
            f"CPCV invariant violated: last group must end at {n_samples}, got {last_end}"
        )

    # Adjacency pass: compare every neighbouring pair (group i-1, group i).
    # This must finish before the emptiness pass so error precedence is stable.
    neighbours = zip(boundaries, boundaries[1:])
    for i, (left, right) in enumerate(neighbours, start=1):
        prev_end = left[1]
        curr_start = right[0]
        if curr_start != prev_end:
            raise ValueError(
                f"CPCV invariant violated: gap between group {i - 1} (ends at {prev_end}) "
                f"and group {i} (starts at {curr_start})"
            )

    # Emptiness pass: each group must hold at least one sample.
    for i, bounds in enumerate(boundaries):
        start, end = bounds
        if start >= end:
            raise ValueError(
                f"CPCV invariant violated: group {i} is empty or invalid (start={start}, end={end})"
            )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def create_session_partitions(
    X: pl.DataFrame | pd.DataFrame,
    session_col: str,
    n_groups: int,
    session_to_indices_fn: Callable[
        [pl.DataFrame | pd.DataFrame, str],
        tuple[list[Any], dict[Any, NDArray[np.intp]]],
    ],
) -> list[NDArray[np.intp]]:
    """Create exact index arrays per group, aligned to session boundaries.

    Contiguous partitioning returns (start, end) ranges. This function instead
    returns EXACT index arrays for each group, which is critical for correct
    behavior with non-contiguous or interleaved data.

    Sessions are assigned to groups in the order given by
    ``session_to_indices_fn``. Each of the first ``n_sessions % n_groups``
    groups receives one extra session.

    Parameters
    ----------
    X : DataFrame
        Data with session column.
    session_col : str
        Name of column containing session identifiers.
    n_groups : int
        Number of groups to create. Must be at least 1.
    session_to_indices_fn : callable
        Function that returns (ordered_sessions, session_to_indices_dict).
        Typically from BaseSplitter._session_to_indices.

    Returns
    -------
    group_indices : list of np.ndarray
        List of numpy arrays containing exact row indices for each group.
        Each array contains the sorted indices of all rows belonging to
        sessions in that group.

    Raises
    ------
    ValueError
        If ``n_groups`` is less than 1, or if there are not enough sessions
        for the requested number of groups.

    Notes
    -----
    The key difference from contiguous partitioning is that we track
    exact indices rather than (start, end) boundaries. This prevents
    incorrect index ranges when data is interleaved by asset within sessions.
    """
    # Without this guard, n_groups == 0 raises ZeroDivisionError below and a
    # negative n_groups silently yields an empty partition.
    if n_groups < 1:
        raise ValueError(f"n_groups must be at least 1, got {n_groups}")

    # Get session -> indices mapping
    ordered_sessions, session_to_indices = session_to_indices_fn(X, session_col)
    n_sessions = len(ordered_sessions)

    if n_sessions < n_groups:
        raise ValueError(
            f"Not enough sessions ({n_sessions}) for {n_groups} groups. "
            f"Need at least {n_groups} sessions."
        )

    # Partition sessions into groups
    base_sessions_per_group = n_sessions // n_groups
    remainder = n_sessions % n_groups

    group_indices_list = []
    current_session_idx = 0

    for i in range(n_groups):
        # Add extra session to first 'remainder' groups
        sessions_in_group = base_sessions_per_group + (1 if i < remainder else 0)
        session_group_end = current_session_idx + sessions_in_group

        # Get sessions for this group
        group_sessions = ordered_sessions[current_session_idx:session_group_end]

        # Collect EXACT indices for sessions in this group
        indices_arrays = [session_to_indices[s] for s in group_sessions]
        if indices_arrays:
            # Sort for predictable ordering (sessions may interleave rows)
            group_indices = np.sort(np.concatenate(indices_arrays))
        else:
            # Defensive: unreachable while n_sessions >= n_groups >= 1
            group_indices = np.array([], dtype=np.intp)

        group_indices_list.append(group_indices)
        current_session_idx = session_group_end

    return group_indices_list
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def boundaries_to_indices(
    boundaries: list[tuple[int, int]],
    groups: tuple[int, ...],
) -> NDArray[np.intp]:
    """Convert group boundaries to flat index array for selected groups.

    Parameters
    ----------
    boundaries : list of tuple
        List of (start_idx, end_idx) for each group, with end_idx exclusive.
        Expected to satisfy the contiguous-partition invariants, so that
        boundaries are ascending in group index.
    groups : tuple of int
        Which groups to include, in any order.

    Returns
    -------
    indices : np.ndarray
        Sorted array (dtype ``np.intp``) of indices for the selected groups.
        Empty when ``groups`` is empty.
    """
    # Visit groups in ascending order. With contiguous, non-overlapping
    # boundaries (as enforced by validate_contiguous_partitions) this makes the
    # concatenated ranges ascending, so the documented sort guarantee holds
    # without an O(n log n) sort over the indices themselves.
    ranges = [
        np.arange(boundaries[g][0], boundaries[g][1], dtype=np.intp)
        for g in sorted(groups)
    ]
    if not ranges:
        return np.array([], dtype=np.intp)
    return np.concatenate(ranges)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def exact_indices_to_array(
    group_indices_list: list[NDArray[np.intp]],
    groups: tuple[int, ...],
) -> NDArray[np.intp]:
    """Merge the exact index arrays of the selected groups into one sorted array.

    Parameters
    ----------
    group_indices_list : list of np.ndarray
        Exact index arrays, one per group.
    groups : tuple of int
        Positions in ``group_indices_list`` to include.

    Returns
    -------
    indices : np.ndarray
        Sorted concatenation of the selected arrays. An empty ``np.intp``
        array is returned when nothing is selected or every selected array
        is empty.
    """
    chosen = [group_indices_list[g] for g in groups]

    # Short-circuit when there is nothing to merge; covers both "no groups"
    # and "only empty groups".
    if not any(len(arr) for arr in chosen):
        return np.array([], dtype=np.intp)

    # np.concatenate always returns a fresh array, so sorting it in place
    # never touches the caller's per-group arrays.
    merged = np.concatenate(chosen)
    merged.sort()
    return merged
|
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
"""Purging engine for CPCV.
|
|
2
|
+
|
|
3
|
+
This module implements the core purging and embargo logic:
|
|
4
|
+
- Mask-based purging (efficient for large datasets)
|
|
5
|
+
- Single-asset and multi-asset purging strategies
|
|
6
|
+
- Segment-based purging for temporal coherence
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from numpy.typing import NDArray
|
|
15
|
+
|
|
16
|
+
from ml4t.diagnostic.core.purging import apply_purging_and_embargo
|
|
17
|
+
from ml4t.diagnostic.splitters.cpcv.windows import (
|
|
18
|
+
find_contiguous_segments,
|
|
19
|
+
timestamp_window_from_indices,
|
|
20
|
+
)
|
|
21
|
+
from ml4t.diagnostic.splitters.utils import convert_indices_to_timestamps
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
import pandas as pd
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def apply_single_asset_purging(
    train_indices: NDArray[np.intp],
    test_group_indices: tuple[int, ...],
    group_boundaries: list[tuple[int, int]],
    n_samples: int,
    timestamps: pd.DatetimeIndex | None,
    label_horizon: int | pd.Timedelta,
    embargo_size: int | pd.Timedelta | None,
    embargo_pct: float | None,
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> NDArray[np.intp]:
    """Purge/embargo training indices against each test group (single asset).

    Test groups are handled sequentially. The indices returned by one
    ``apply_purging_and_embargo`` call are the input to the next, so the
    result accounts for every group in ``test_group_indices``.

    How each test group's window is obtained:

    * When both ``group_indices_list`` and ``timestamps`` are given
      (session-aligned mode), ``timestamp_window_from_indices`` is applied to
      the group's exact row indices. If it returns ``None`` (empty test
      group), that group is skipped.
    * Otherwise, the group's ``(start, end)`` entry in ``group_boundaries`` is
      converted with ``convert_indices_to_timestamps``.

    Parameters
    ----------
    train_indices : ndarray
        Initial training indices.
    test_group_indices : tuple of int
        Indices of groups used for testing.
    group_boundaries : list of tuple
        Boundaries (start, end) for each group.
    n_samples : int
        Total number of samples.
    timestamps : pd.DatetimeIndex, optional
        Timestamps for time-based purging.
    label_horizon : int or pd.Timedelta
        Forward-looking period of labels.
    embargo_size : int or pd.Timedelta, optional
        Buffer period after test set.
    embargo_pct : float, optional
        Embargo as percentage of samples.
    group_indices_list : list of ndarray, optional
        Exact indices per group (for session-aligned mode).

    Returns
    -------
    clean_indices : ndarray
        Training indices remaining after all test groups are applied.
    """
    # The mode depends only on arguments, so decide it once up front.
    use_exact_windows = group_indices_list is not None and timestamps is not None
    remaining = train_indices

    for group_id in test_group_indices:
        if use_exact_windows:
            window = timestamp_window_from_indices(group_indices_list[group_id], timestamps)
            if window is None:
                # Empty test group: nothing to purge against.
                continue
            window_start, window_end = window.start, window.end_exclusive
        else:
            first_idx, end_idx = group_boundaries[group_id]
            window_start, window_end = convert_indices_to_timestamps(
                first_idx,
                end_idx,
                timestamps,
            )

        remaining = apply_purging_and_embargo(
            train_indices=remaining,
            test_start=window_start,
            test_end=window_end,
            label_horizon=label_horizon,
            embargo_size=embargo_size,
            embargo_pct=embargo_pct,
            n_samples=n_samples,
            timestamps=timestamps,
        )

    return remaining
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def apply_multi_asset_purging(
    train_indices: NDArray[np.intp],
    test_group_indices: tuple[int, ...],
    group_boundaries: list[tuple[int, int]],
    n_samples: int,
    timestamps: pd.DatetimeIndex | None,
    groups_array: NDArray[Any],
    label_horizon: int | pd.Timedelta,
    embargo_size: int | pd.Timedelta | None,
    embargo_pct: float | None,
    group_indices_list: list[NDArray[np.intp]] | None = None,
) -> NDArray[np.intp]:
    """Purge training indices asset by asset for multi-asset data.

    The test groups are prepared once with ``prepare_test_groups_data``, which
    orders them by start index. Then ``process_asset_purging`` is called for
    each distinct value of ``groups_array``, in ``np.unique`` order. The
    per-asset survivors are pooled and returned sorted.

    Parameters
    ----------
    train_indices : ndarray
        Initial training indices, passed unchanged to every per-asset call.
    test_group_indices : tuple of int
        Indices of groups used for testing.
    group_boundaries : list of tuple
        Boundaries (start, end) for each group.
    n_samples : int
        Total number of samples.
    timestamps : pd.DatetimeIndex, optional
        Timestamps for time-based purging.
    groups_array : ndarray
        Asset labels for each sample. Its length must equal ``n_samples``.
    label_horizon : int or pd.Timedelta
        Forward-looking period of labels.
    embargo_size : int or pd.Timedelta, optional
        Buffer period after test set.
    embargo_pct : float, optional
        Embargo as percentage of samples.
    group_indices_list : list of ndarray, optional
        Exact indices per group (for session-aligned mode).

    Returns
    -------
    clean_indices : ndarray
        Sorted ``np.intp`` training indices after per-asset purging.

    Raises
    ------
    ValueError
        If ``len(groups_array) != n_samples``.
    """
    if len(groups_array) != n_samples:
        raise ValueError(
            f"groups length ({len(groups_array)}) must match number of samples ({n_samples})",
        )

    # Test groups ordered by start index; identical for every asset.
    sorted_test_groups = prepare_test_groups_data(
        test_group_indices, group_boundaries, group_indices_list
    )

    # Arguments that do not vary between assets.
    common_kwargs: dict[str, Any] = {
        "groups_array": groups_array,
        "train_indices": train_indices,
        "test_groups_data": sorted_test_groups,
        "n_samples": n_samples,
        "timestamps": timestamps,
        "label_horizon": label_horizon,
        "embargo_size": embargo_size,
        "embargo_pct": embargo_pct,
        "group_indices_list": group_indices_list,
    }

    kept: list[int] = []
    for asset_id in np.unique(groups_array):
        kept.extend(process_asset_purging(asset_id=asset_id, **common_kwargs))

    # Sort so the output does not depend on per-asset processing order.
    return np.sort(np.array(kept, dtype=np.intp))
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def prepare_test_groups_data(
|
|
184
|
+
test_group_indices: tuple[int, ...],
|
|
185
|
+
group_boundaries: list[tuple[int, int]],
|
|
186
|
+
group_indices_list: list[NDArray[np.intp]] | None = None,
|
|
187
|
+
) -> list[tuple[int, int, int, NDArray[np.intp] | None]]:
|
|
188
|
+
"""Prepare and sort test groups data for contiguous segment detection.
|
|
189
|
+
|
|
190
|
+
Parameters
|
|
191
|
+
----------
|
|
192
|
+
test_group_indices : tuple of int
|
|
193
|
+
Which groups are used for testing.
|
|
194
|
+
group_boundaries : list of tuple
|
|
195
|
+
Boundaries (start, end) for each group.
|
|
196
|
+
group_indices_list : list of ndarray, optional
|
|
197
|
+
Exact indices per group (for session-aligned mode).
|
|
198
|
+
|
|
199
|
+
Returns
|
|
200
|
+
-------
|
|
201
|
+
test_groups_data : list of tuple
|
|
202
|
+
Sorted list of (group_idx, start_idx, end_idx, exact_indices).
|
|
203
|
+
In session-aligned mode, exact_indices contains the actual row indices;
|
|
204
|
+
otherwise it's None.
|
|
205
|
+
"""
|
|
206
|
+
test_groups_data: list[tuple[int, int, int, NDArray[np.intp] | None]] = []
|
|
207
|
+
for test_group_idx in test_group_indices:
|
|
208
|
+
test_start_idx, test_end_idx = group_boundaries[test_group_idx]
|
|
209
|
+
exact_indices = (
|
|
210
|
+
group_indices_list[test_group_idx] if group_indices_list is not None else None
|
|
211
|
+
)
|
|
212
|
+
test_groups_data.append((test_group_idx, test_start_idx, test_end_idx, exact_indices))
|
|
213
|
+
|
|
214
|
+
# Sort test groups by start index to identify contiguous segments
|
|
215
|
+
test_groups_data.sort(key=lambda x: x[1])
|
|
216
|
+
return test_groups_data
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def process_asset_purging(
|
|
220
|
+
asset_id: Any,
|
|
221
|
+
groups_array: NDArray[Any],
|
|
222
|
+
train_indices: NDArray[np.intp],
|
|
223
|
+
test_groups_data: list[tuple[int, int, int, NDArray[np.intp] | None]],
|
|
224
|
+
n_samples: int,
|
|
225
|
+
timestamps: pd.DatetimeIndex | None,
|
|
226
|
+
label_horizon: int | pd.Timedelta,
|
|
227
|
+
embargo_size: int | pd.Timedelta | None,
|
|
228
|
+
embargo_pct: float | None,
|
|
229
|
+
group_indices_list: list[NDArray[np.intp]] | None = None,
|
|
230
|
+
) -> list[int]:
|
|
231
|
+
"""Process purging for a single asset across all test segments.
|
|
232
|
+
|
|
233
|
+
Parameters
|
|
234
|
+
----------
|
|
235
|
+
asset_id : any
|
|
236
|
+
Identifier for this asset.
|
|
237
|
+
groups_array : ndarray
|
|
238
|
+
Asset labels for all samples.
|
|
239
|
+
train_indices : ndarray
|
|
240
|
+
Candidate training indices.
|
|
241
|
+
test_groups_data : list of tuple
|
|
242
|
+
Test group information from prepare_test_groups_data.
|
|
243
|
+
n_samples : int
|
|
244
|
+
Total number of samples.
|
|
245
|
+
timestamps : pd.DatetimeIndex, optional
|
|
246
|
+
Timestamps for time-based purging.
|
|
247
|
+
label_horizon : int or pd.Timedelta
|
|
248
|
+
Forward-looking period of labels.
|
|
249
|
+
embargo_size : int or pd.Timedelta, optional
|
|
250
|
+
Buffer period after test set.
|
|
251
|
+
embargo_pct : float, optional
|
|
252
|
+
Embargo as percentage of samples.
|
|
253
|
+
group_indices_list : list of ndarray, optional
|
|
254
|
+
Exact indices per group (for session-aligned mode).
|
|
255
|
+
|
|
256
|
+
Returns
|
|
257
|
+
-------
|
|
258
|
+
clean_indices : list of int
|
|
259
|
+
Training indices for this asset after purging.
|
|
260
|
+
"""
|
|
261
|
+
# Find indices for this asset
|
|
262
|
+
asset_mask = groups_array == asset_id
|
|
263
|
+
asset_indices = np.where(asset_mask)[0]
|
|
264
|
+
|
|
265
|
+
# Get train indices for this asset
|
|
266
|
+
asset_train_indices = np.intersect1d(train_indices, asset_indices)
|
|
267
|
+
|
|
268
|
+
if len(asset_train_indices) == 0:
|
|
269
|
+
return []
|
|
270
|
+
|
|
271
|
+
# Find contiguous segments of test groups for this asset
|
|
272
|
+
contiguous_segments = find_contiguous_segments(
|
|
273
|
+
test_groups_data,
|
|
274
|
+
asset_indices,
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
# If no test data for this asset, keep all training data
|
|
278
|
+
if not contiguous_segments:
|
|
279
|
+
return asset_train_indices.tolist()
|
|
280
|
+
|
|
281
|
+
# Apply purging for each contiguous segment
|
|
282
|
+
return apply_segment_purging(
|
|
283
|
+
asset_train_indices=asset_train_indices,
|
|
284
|
+
contiguous_segments=contiguous_segments,
|
|
285
|
+
n_samples=n_samples,
|
|
286
|
+
timestamps=timestamps,
|
|
287
|
+
label_horizon=label_horizon,
|
|
288
|
+
embargo_size=embargo_size,
|
|
289
|
+
embargo_pct=embargo_pct,
|
|
290
|
+
group_indices_list=group_indices_list,
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def apply_segment_purging(
|
|
295
|
+
asset_train_indices: NDArray[np.intp],
|
|
296
|
+
contiguous_segments: list[list[tuple[int, int, int, NDArray[np.intp]]]],
|
|
297
|
+
n_samples: int,
|
|
298
|
+
timestamps: pd.DatetimeIndex | None,
|
|
299
|
+
label_horizon: int | pd.Timedelta,
|
|
300
|
+
embargo_size: int | pd.Timedelta | None,
|
|
301
|
+
embargo_pct: float | None,
|
|
302
|
+
group_indices_list: list[NDArray[np.intp]] | None = None,
|
|
303
|
+
) -> list[int]:
|
|
304
|
+
"""Apply purging across all contiguous segments for an asset.
|
|
305
|
+
|
|
306
|
+
Uses a set-based approach for tracking remaining indices, which is
|
|
307
|
+
efficient for the iterative purging across segments.
|
|
308
|
+
|
|
309
|
+
Parameters
|
|
310
|
+
----------
|
|
311
|
+
asset_train_indices : ndarray
|
|
312
|
+
Training indices for this asset.
|
|
313
|
+
contiguous_segments : list of list of tuple
|
|
314
|
+
Segments from find_contiguous_segments.
|
|
315
|
+
n_samples : int
|
|
316
|
+
Total number of samples.
|
|
317
|
+
timestamps : pd.DatetimeIndex, optional
|
|
318
|
+
Timestamps for time-based purging.
|
|
319
|
+
label_horizon : int or pd.Timedelta
|
|
320
|
+
Forward-looking period of labels.
|
|
321
|
+
embargo_size : int or pd.Timedelta, optional
|
|
322
|
+
Buffer period after test set.
|
|
323
|
+
embargo_pct : float, optional
|
|
324
|
+
Embargo as percentage of samples.
|
|
325
|
+
group_indices_list : list of ndarray, optional
|
|
326
|
+
Exact indices per group (for session-aligned mode).
|
|
327
|
+
|
|
328
|
+
Returns
|
|
329
|
+
-------
|
|
330
|
+
clean_indices : list of int
|
|
331
|
+
Sorted training indices after purging all segments.
|
|
332
|
+
"""
|
|
333
|
+
remaining_train_indices = set(asset_train_indices)
|
|
334
|
+
|
|
335
|
+
for segment in contiguous_segments:
|
|
336
|
+
if not segment:
|
|
337
|
+
continue
|
|
338
|
+
|
|
339
|
+
# Compute purge window bounds
|
|
340
|
+
if group_indices_list is not None and timestamps is not None:
|
|
341
|
+
# Session-aligned mode: compute timestamp bounds from actual test indices
|
|
342
|
+
segment_test_indices = np.concatenate([item[3] for item in segment])
|
|
343
|
+
window = timestamp_window_from_indices(segment_test_indices, timestamps)
|
|
344
|
+
if window is None:
|
|
345
|
+
# Empty test segment - skip purging for this segment
|
|
346
|
+
continue
|
|
347
|
+
segment_start_time = window.start
|
|
348
|
+
segment_end_time = window.end_exclusive
|
|
349
|
+
else:
|
|
350
|
+
# Standard mode: use boundaries
|
|
351
|
+
segment_start_idx = segment[0][1] # Start of first group in segment
|
|
352
|
+
segment_end_idx = segment[-1][2] # End of last group in segment
|
|
353
|
+
segment_start_time, segment_end_time = convert_indices_to_timestamps(
|
|
354
|
+
segment_start_idx,
|
|
355
|
+
segment_end_idx,
|
|
356
|
+
timestamps,
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
# Apply purging for this contiguous segment
|
|
360
|
+
remaining_array = np.array(list(remaining_train_indices), dtype=np.intp)
|
|
361
|
+
|
|
362
|
+
if len(remaining_array) == 0:
|
|
363
|
+
break
|
|
364
|
+
|
|
365
|
+
clean_segment_train = apply_purging_and_embargo(
|
|
366
|
+
train_indices=remaining_array,
|
|
367
|
+
test_start=segment_start_time,
|
|
368
|
+
test_end=segment_end_time,
|
|
369
|
+
label_horizon=label_horizon,
|
|
370
|
+
embargo_size=embargo_size,
|
|
371
|
+
embargo_pct=embargo_pct,
|
|
372
|
+
n_samples=n_samples,
|
|
373
|
+
timestamps=timestamps,
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
# Update remaining indices (remove those that were purged)
|
|
377
|
+
remaining_train_indices = set(clean_segment_train)
|
|
378
|
+
|
|
379
|
+
return sorted(remaining_train_indices)
|