ml4t-diagnostic 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4t/diagnostic/AGENT.md +25 -0
- ml4t/diagnostic/__init__.py +166 -0
- ml4t/diagnostic/backends/__init__.py +10 -0
- ml4t/diagnostic/backends/adapter.py +192 -0
- ml4t/diagnostic/backends/polars_backend.py +899 -0
- ml4t/diagnostic/caching/__init__.py +40 -0
- ml4t/diagnostic/caching/cache.py +331 -0
- ml4t/diagnostic/caching/decorators.py +131 -0
- ml4t/diagnostic/caching/smart_cache.py +339 -0
- ml4t/diagnostic/config/AGENT.md +24 -0
- ml4t/diagnostic/config/README.md +267 -0
- ml4t/diagnostic/config/__init__.py +219 -0
- ml4t/diagnostic/config/barrier_config.py +277 -0
- ml4t/diagnostic/config/base.py +301 -0
- ml4t/diagnostic/config/event_config.py +148 -0
- ml4t/diagnostic/config/feature_config.py +404 -0
- ml4t/diagnostic/config/multi_signal_config.py +55 -0
- ml4t/diagnostic/config/portfolio_config.py +215 -0
- ml4t/diagnostic/config/report_config.py +391 -0
- ml4t/diagnostic/config/sharpe_config.py +202 -0
- ml4t/diagnostic/config/signal_config.py +206 -0
- ml4t/diagnostic/config/trade_analysis_config.py +310 -0
- ml4t/diagnostic/config/validation.py +279 -0
- ml4t/diagnostic/core/__init__.py +29 -0
- ml4t/diagnostic/core/numba_utils.py +315 -0
- ml4t/diagnostic/core/purging.py +372 -0
- ml4t/diagnostic/core/sampling.py +471 -0
- ml4t/diagnostic/errors/__init__.py +205 -0
- ml4t/diagnostic/evaluation/AGENT.md +26 -0
- ml4t/diagnostic/evaluation/__init__.py +437 -0
- ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
- ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
- ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
- ml4t/diagnostic/evaluation/dashboard.py +715 -0
- ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
- ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
- ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
- ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
- ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
- ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
- ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
- ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
- ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
- ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
- ml4t/diagnostic/evaluation/event_analysis.py +647 -0
- ml4t/diagnostic/evaluation/excursion.py +390 -0
- ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
- ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
- ml4t/diagnostic/evaluation/framework.py +935 -0
- ml4t/diagnostic/evaluation/metric_registry.py +255 -0
- ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
- ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
- ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
- ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
- ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
- ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
- ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
- ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
- ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
- ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
- ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
- ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
- ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
- ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
- ml4t/diagnostic/evaluation/multi_signal.py +550 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
- ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
- ml4t/diagnostic/evaluation/report_generation.py +824 -0
- ml4t/diagnostic/evaluation/signal_selector.py +452 -0
- ml4t/diagnostic/evaluation/stat_registry.py +139 -0
- ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
- ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
- ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
- ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
- ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
- ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
- ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
- ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
- ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
- ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
- ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
- ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
- ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
- ml4t/diagnostic/evaluation/stats/moments.py +164 -0
- ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
- ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
- ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
- ml4t/diagnostic/evaluation/themes.py +330 -0
- ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
- ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
- ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
- ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
- ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
- ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
- ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
- ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
- ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
- ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
- ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
- ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
- ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
- ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
- ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
- ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
- ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
- ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
- ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
- ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
- ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
- ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
- ml4t/diagnostic/evaluation/validated_cv.py +535 -0
- ml4t/diagnostic/evaluation/visualization.py +1050 -0
- ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
- ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
- ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
- ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
- ml4t/diagnostic/integration/__init__.py +48 -0
- ml4t/diagnostic/integration/backtest_contract.py +671 -0
- ml4t/diagnostic/integration/data_contract.py +316 -0
- ml4t/diagnostic/integration/engineer_contract.py +226 -0
- ml4t/diagnostic/logging/__init__.py +77 -0
- ml4t/diagnostic/logging/logger.py +245 -0
- ml4t/diagnostic/logging/performance.py +234 -0
- ml4t/diagnostic/logging/progress.py +234 -0
- ml4t/diagnostic/logging/wandb.py +412 -0
- ml4t/diagnostic/metrics/__init__.py +9 -0
- ml4t/diagnostic/metrics/percentiles.py +128 -0
- ml4t/diagnostic/py.typed +1 -0
- ml4t/diagnostic/reporting/__init__.py +43 -0
- ml4t/diagnostic/reporting/base.py +130 -0
- ml4t/diagnostic/reporting/html_renderer.py +275 -0
- ml4t/diagnostic/reporting/json_renderer.py +51 -0
- ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
- ml4t/diagnostic/results/AGENT.md +24 -0
- ml4t/diagnostic/results/__init__.py +105 -0
- ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
- ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
- ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
- ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
- ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
- ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
- ml4t/diagnostic/results/barrier_results/validation.py +38 -0
- ml4t/diagnostic/results/base.py +177 -0
- ml4t/diagnostic/results/event_results.py +349 -0
- ml4t/diagnostic/results/feature_results.py +787 -0
- ml4t/diagnostic/results/multi_signal_results.py +431 -0
- ml4t/diagnostic/results/portfolio_results.py +281 -0
- ml4t/diagnostic/results/sharpe_results.py +448 -0
- ml4t/diagnostic/results/signal_results/__init__.py +74 -0
- ml4t/diagnostic/results/signal_results/ic.py +581 -0
- ml4t/diagnostic/results/signal_results/irtc.py +110 -0
- ml4t/diagnostic/results/signal_results/quantile.py +392 -0
- ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
- ml4t/diagnostic/results/signal_results/turnover.py +213 -0
- ml4t/diagnostic/results/signal_results/validation.py +147 -0
- ml4t/diagnostic/signal/AGENT.md +17 -0
- ml4t/diagnostic/signal/__init__.py +69 -0
- ml4t/diagnostic/signal/_report.py +152 -0
- ml4t/diagnostic/signal/_utils.py +261 -0
- ml4t/diagnostic/signal/core.py +275 -0
- ml4t/diagnostic/signal/quantile.py +148 -0
- ml4t/diagnostic/signal/result.py +214 -0
- ml4t/diagnostic/signal/signal_ic.py +129 -0
- ml4t/diagnostic/signal/turnover.py +182 -0
- ml4t/diagnostic/splitters/AGENT.md +19 -0
- ml4t/diagnostic/splitters/__init__.py +36 -0
- ml4t/diagnostic/splitters/base.py +501 -0
- ml4t/diagnostic/splitters/calendar.py +421 -0
- ml4t/diagnostic/splitters/calendar_config.py +91 -0
- ml4t/diagnostic/splitters/combinatorial.py +1064 -0
- ml4t/diagnostic/splitters/config.py +322 -0
- ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
- ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
- ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
- ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
- ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
- ml4t/diagnostic/splitters/group_isolation.py +329 -0
- ml4t/diagnostic/splitters/persistence.py +316 -0
- ml4t/diagnostic/splitters/utils.py +207 -0
- ml4t/diagnostic/splitters/walk_forward.py +757 -0
- ml4t/diagnostic/utils/__init__.py +42 -0
- ml4t/diagnostic/utils/config.py +542 -0
- ml4t/diagnostic/utils/dependencies.py +318 -0
- ml4t/diagnostic/utils/sessions.py +127 -0
- ml4t/diagnostic/validation/__init__.py +54 -0
- ml4t/diagnostic/validation/dataframe.py +274 -0
- ml4t/diagnostic/validation/returns.py +280 -0
- ml4t/diagnostic/validation/timeseries.py +299 -0
- ml4t/diagnostic/visualization/AGENT.md +19 -0
- ml4t/diagnostic/visualization/__init__.py +223 -0
- ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
- ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
- ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
- ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
- ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
- ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
- ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
- ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
- ml4t/diagnostic/visualization/barrier_plots.py +782 -0
- ml4t/diagnostic/visualization/core.py +1060 -0
- ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
- ml4t/diagnostic/visualization/dashboards/base.py +582 -0
- ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
- ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
- ml4t/diagnostic/visualization/dashboards.py +43 -0
- ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
- ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
- ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
- ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
- ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
- ml4t/diagnostic/visualization/feature_plots.py +888 -0
- ml4t/diagnostic/visualization/interaction_plots.py +618 -0
- ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
- ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
- ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
- ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
- ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
- ml4t/diagnostic/visualization/report_generation.py +1343 -0
- ml4t/diagnostic/visualization/signal/__init__.py +103 -0
- ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
- ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
- ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
- ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
- ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
- ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
- ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
- ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
- ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
- ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
- ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
- ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
"""Core purging and embargo functionality for time-series cross-validation.
|
|
2
|
+
|
|
3
|
+
This module implements the fundamental algorithms for preventing data leakage
|
|
4
|
+
in financial time-series validation through purging (removing training samples
|
|
5
|
+
whose labels overlap with test data) and embargo (adding gaps to account for
|
|
6
|
+
serial correlation).
|
|
7
|
+
|
|
8
|
+
Based on López de Prado (2018) "Advances in Financial Machine Learning".
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from typing import TYPE_CHECKING, cast
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from numpy.typing import NDArray
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def calculate_purge_indices(
|
|
21
|
+
n_samples: int | None = None,
|
|
22
|
+
test_start: int | pd.Timestamp | None = None,
|
|
23
|
+
test_end: int | pd.Timestamp | None = None,
|
|
24
|
+
label_horizon: int | pd.Timedelta = 0,
|
|
25
|
+
timestamps: pd.DatetimeIndex | None = None,
|
|
26
|
+
) -> list[int]:
|
|
27
|
+
"""Calculate indices to purge from training set to prevent label leakage.
|
|
28
|
+
|
|
29
|
+
Purging removes training samples whose labels could contain information
|
|
30
|
+
from the test period. If a feature at time t is used to predict a label
|
|
31
|
+
that depends on information up to time t+h, we must remove training
|
|
32
|
+
samples from [test_start - h, test_start).
|
|
33
|
+
|
|
34
|
+
Parameters
|
|
35
|
+
----------
|
|
36
|
+
n_samples : int, optional
|
|
37
|
+
Total number of samples when using integer indices.
|
|
38
|
+
|
|
39
|
+
test_start : int or pandas.Timestamp
|
|
40
|
+
Start index/time of test period.
|
|
41
|
+
|
|
42
|
+
test_end : int or pandas.Timestamp
|
|
43
|
+
End index/time of test period (exclusive).
|
|
44
|
+
|
|
45
|
+
label_horizon : int or pandas.Timedelta, default=0
|
|
46
|
+
Forward-looking period of labels. For example, if predicting
|
|
47
|
+
20-day returns, label_horizon=20 (days).
|
|
48
|
+
|
|
49
|
+
timestamps : pandas.DatetimeIndex, optional
|
|
50
|
+
Timestamps for each sample when using time-based indices.
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
-------
|
|
54
|
+
purged_indices : list of int
|
|
55
|
+
Integer positions of samples to remove from training set.
|
|
56
|
+
|
|
57
|
+
Examples:
|
|
58
|
+
--------
|
|
59
|
+
>>> # Integer indices
|
|
60
|
+
>>> purged = calculate_purge_indices(
|
|
61
|
+
... n_samples=100, test_start=50, test_end=60, label_horizon=5
|
|
62
|
+
... )
|
|
63
|
+
>>> purged
|
|
64
|
+
[45, 46, 47, 48, 49]
|
|
65
|
+
|
|
66
|
+
>>> # Timestamp indices
|
|
67
|
+
>>> times = pd.date_range("2020-01-01", periods=100, freq="D")
|
|
68
|
+
>>> purged = calculate_purge_indices(
|
|
69
|
+
... timestamps=times,
|
|
70
|
+
... test_start=times[50],
|
|
71
|
+
... test_end=times[60],
|
|
72
|
+
... label_horizon=pd.Timedelta("5D")
|
|
73
|
+
... )
|
|
74
|
+
"""
|
|
75
|
+
if timestamps is not None:
|
|
76
|
+
# Time-based purging
|
|
77
|
+
if not isinstance(test_start, pd.Timestamp) or not isinstance(
|
|
78
|
+
test_end,
|
|
79
|
+
pd.Timestamp,
|
|
80
|
+
):
|
|
81
|
+
raise TypeError(
|
|
82
|
+
"test_start and test_end must be Timestamps when using timestamps",
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
# Validate timezone awareness
|
|
86
|
+
if timestamps.tz is None:
|
|
87
|
+
raise ValueError(
|
|
88
|
+
"timestamps must be timezone-aware. Use timestamps.tz_localize('UTC') or timestamps.tz_convert('UTC')"
|
|
89
|
+
)
|
|
90
|
+
if test_start.tz is None:
|
|
91
|
+
raise ValueError(
|
|
92
|
+
"test_start must be timezone-aware when using timestamps. "
|
|
93
|
+
"Use pd.Timestamp(test_start, tz='UTC') or test_start.tz_localize('UTC')"
|
|
94
|
+
)
|
|
95
|
+
if test_end.tz is None:
|
|
96
|
+
raise ValueError(
|
|
97
|
+
"test_end must be timezone-aware when using timestamps. "
|
|
98
|
+
"Use pd.Timestamp(test_end, tz='UTC') or test_end.tz_localize('UTC')"
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
# Convert all to UTC for consistent calculations
|
|
102
|
+
timestamps = timestamps.tz_convert("UTC")
|
|
103
|
+
test_start = test_start.tz_convert("UTC")
|
|
104
|
+
test_end = test_end.tz_convert("UTC")
|
|
105
|
+
|
|
106
|
+
if not isinstance(label_horizon, pd.Timedelta):
|
|
107
|
+
# Convert integer days to Timedelta
|
|
108
|
+
label_horizon = pd.Timedelta(days=label_horizon)
|
|
109
|
+
|
|
110
|
+
# Calculate purge start time
|
|
111
|
+
purge_start_time = test_start - label_horizon
|
|
112
|
+
|
|
113
|
+
# Find indices to purge
|
|
114
|
+
purge_mask = (timestamps >= purge_start_time) & (timestamps < test_start)
|
|
115
|
+
purged_indices = np.where(purge_mask)[0].tolist()
|
|
116
|
+
|
|
117
|
+
else:
|
|
118
|
+
# Integer-based purging
|
|
119
|
+
if n_samples is None:
|
|
120
|
+
raise ValueError("n_samples required for integer-based purging")
|
|
121
|
+
|
|
122
|
+
# In this branch, test_start and label_horizon are integers
|
|
123
|
+
test_start_int = cast(int, test_start)
|
|
124
|
+
label_horizon_int = cast(int, label_horizon)
|
|
125
|
+
|
|
126
|
+
# Calculate purge start
|
|
127
|
+
purge_start = max(0, test_start_int - label_horizon_int)
|
|
128
|
+
|
|
129
|
+
# Indices to purge are [purge_start, test_start)
|
|
130
|
+
purged_indices = list(range(purge_start, test_start_int))
|
|
131
|
+
|
|
132
|
+
return purged_indices
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def calculate_embargo_indices(
|
|
136
|
+
n_samples: int | None = None,
|
|
137
|
+
test_start: int | pd.Timestamp | None = None,
|
|
138
|
+
test_end: int | pd.Timestamp | None = None,
|
|
139
|
+
embargo_size: int | pd.Timedelta | None = None,
|
|
140
|
+
embargo_pct: float | None = None,
|
|
141
|
+
timestamps: pd.DatetimeIndex | None = None,
|
|
142
|
+
) -> list[int]:
|
|
143
|
+
"""Calculate indices to embargo after test set to prevent serial correlation.
|
|
144
|
+
|
|
145
|
+
Embargo removes training samples immediately after the test set to account
|
|
146
|
+
for serial correlation in predictions. This prevents the model from learning
|
|
147
|
+
patterns that persist across the test/train boundary.
|
|
148
|
+
|
|
149
|
+
Parameters
|
|
150
|
+
----------
|
|
151
|
+
n_samples : int, optional
|
|
152
|
+
Total number of samples when using integer indices.
|
|
153
|
+
|
|
154
|
+
test_start : int or pandas.Timestamp
|
|
155
|
+
Start index/time of test period.
|
|
156
|
+
|
|
157
|
+
test_end : int or pandas.Timestamp
|
|
158
|
+
End index/time of test period (exclusive).
|
|
159
|
+
|
|
160
|
+
embargo_size : int or pandas.Timedelta, optional
|
|
161
|
+
Size of embargo period after test set.
|
|
162
|
+
|
|
163
|
+
embargo_pct : float, optional
|
|
164
|
+
Embargo size as percentage of total samples.
|
|
165
|
+
Either embargo_size or embargo_pct should be specified.
|
|
166
|
+
|
|
167
|
+
timestamps : pandas.DatetimeIndex, optional
|
|
168
|
+
Timestamps for each sample when using time-based indices.
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
-------
|
|
172
|
+
embargo_indices : list of int
|
|
173
|
+
Integer positions of samples to embargo.
|
|
174
|
+
|
|
175
|
+
Examples:
|
|
176
|
+
--------
|
|
177
|
+
>>> # Fixed embargo size
|
|
178
|
+
>>> embargoed = calculate_embargo_indices(
|
|
179
|
+
... n_samples=100, test_start=50, test_end=60, embargo_size=5
|
|
180
|
+
... )
|
|
181
|
+
>>> embargoed
|
|
182
|
+
[60, 61, 62, 63, 64]
|
|
183
|
+
|
|
184
|
+
>>> # Percentage embargo
|
|
185
|
+
>>> embargoed = calculate_embargo_indices(
|
|
186
|
+
... n_samples=100, test_start=50, test_end=60, embargo_pct=0.05
|
|
187
|
+
... )
|
|
188
|
+
"""
|
|
189
|
+
if embargo_size is None and embargo_pct is None:
|
|
190
|
+
return []
|
|
191
|
+
|
|
192
|
+
if embargo_size is not None and embargo_pct is not None:
|
|
193
|
+
raise ValueError("Specify either embargo_size or embargo_pct, not both")
|
|
194
|
+
|
|
195
|
+
if timestamps is not None:
|
|
196
|
+
# Time-based embargo
|
|
197
|
+
if not isinstance(test_start, pd.Timestamp) or not isinstance(
|
|
198
|
+
test_end,
|
|
199
|
+
pd.Timestamp,
|
|
200
|
+
):
|
|
201
|
+
raise TypeError(
|
|
202
|
+
"test_start and test_end must be Timestamps when using timestamps",
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
# Validate timezone awareness
|
|
206
|
+
if timestamps.tz is None:
|
|
207
|
+
raise ValueError(
|
|
208
|
+
"timestamps must be timezone-aware. Use timestamps.tz_localize('UTC') or timestamps.tz_convert('UTC')"
|
|
209
|
+
)
|
|
210
|
+
if test_start.tz is None:
|
|
211
|
+
raise ValueError(
|
|
212
|
+
"test_start must be timezone-aware when using timestamps. "
|
|
213
|
+
"Use pd.Timestamp(test_start, tz='UTC') or test_start.tz_localize('UTC')"
|
|
214
|
+
)
|
|
215
|
+
if test_end.tz is None:
|
|
216
|
+
raise ValueError(
|
|
217
|
+
"test_end must be timezone-aware when using timestamps. "
|
|
218
|
+
"Use pd.Timestamp(test_end, tz='UTC') or test_end.tz_localize('UTC')"
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
# Convert all to UTC for consistent calculations
|
|
222
|
+
timestamps = timestamps.tz_convert("UTC")
|
|
223
|
+
test_start = test_start.tz_convert("UTC")
|
|
224
|
+
test_end = test_end.tz_convert("UTC")
|
|
225
|
+
|
|
226
|
+
# Calculate embargo size if percentage given
|
|
227
|
+
if embargo_pct is not None:
|
|
228
|
+
total_duration = timestamps[-1] - timestamps[0]
|
|
229
|
+
embargo_size = total_duration * embargo_pct
|
|
230
|
+
|
|
231
|
+
if not isinstance(embargo_size, pd.Timedelta):
|
|
232
|
+
# Convert integer days to Timedelta
|
|
233
|
+
embargo_size = pd.Timedelta(days=cast(int, embargo_size))
|
|
234
|
+
|
|
235
|
+
# Calculate embargo end time
|
|
236
|
+
embargo_end_time = test_end + embargo_size
|
|
237
|
+
|
|
238
|
+
# Find indices to embargo
|
|
239
|
+
embargo_mask = (timestamps >= test_end) & (timestamps < embargo_end_time)
|
|
240
|
+
embargo_indices = np.where(embargo_mask)[0].tolist()
|
|
241
|
+
|
|
242
|
+
else:
|
|
243
|
+
# Integer-based embargo
|
|
244
|
+
if n_samples is None:
|
|
245
|
+
raise ValueError("n_samples required for integer-based embargo")
|
|
246
|
+
|
|
247
|
+
# Calculate embargo size if percentage given
|
|
248
|
+
if embargo_pct is not None:
|
|
249
|
+
embargo_size = int(n_samples * embargo_pct)
|
|
250
|
+
|
|
251
|
+
# Calculate embargo end
|
|
252
|
+
# Either embargo_size was provided or calculated from embargo_pct
|
|
253
|
+
assert embargo_size is not None
|
|
254
|
+
# In this branch, test_end and embargo_size are integers
|
|
255
|
+
test_end_int = cast(int, test_end)
|
|
256
|
+
embargo_size_int = cast(int, embargo_size)
|
|
257
|
+
embargo_end = min(n_samples, test_end_int + embargo_size_int)
|
|
258
|
+
|
|
259
|
+
# Indices to embargo are [test_end, embargo_end)
|
|
260
|
+
embargo_indices = list(range(test_end_int, embargo_end))
|
|
261
|
+
|
|
262
|
+
return embargo_indices
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def apply_purging_and_embargo(
    train_indices: "NDArray[np.intp]",
    test_start: int | pd.Timestamp,
    test_end: int | pd.Timestamp,
    label_horizon: int | pd.Timedelta = 0,
    embargo_size: int | pd.Timedelta | None = None,
    embargo_pct: float | None = None,
    n_samples: int | None = None,
    timestamps: pd.DatetimeIndex | None = None,
) -> "NDArray[np.intp]":
    """Apply both purging and embargo to training indices.

    This is a convenience function that combines purging and embargo
    to clean a set of training indices, removing any that could lead
    to data leakage or serial correlation issues. Positions belonging to
    the test period itself, ``[test_start, test_end)``, are removed too.

    Parameters
    ----------
    train_indices : numpy.ndarray
        Initial training indices before purging/embargo. Other array-likes
        are converted with ``numpy.asarray``.

    test_start : int or pandas.Timestamp
        Start index/time of test period.

    test_end : int or pandas.Timestamp
        End index/time of test period (exclusive).

    label_horizon : int or pandas.Timedelta, default=0
        Forward-looking period of labels.

    embargo_size : int or pandas.Timedelta, optional
        Size of embargo period after test set.

    embargo_pct : float, optional
        Embargo size as percentage of total samples.

    n_samples : int, optional
        Total number of samples (required for integer indices).

    timestamps : pandas.DatetimeIndex, optional
        Timestamps for each sample when using time-based indices.

    Returns
    -------
    clean_indices : numpy.ndarray
        Training indices after removing purged, embargoed and test samples,
        in their original order.

    Raises
    ------
    TypeError
        If ``timestamps`` is None and ``test_start`` or ``test_end`` is not
        an integer. Errors raised by ``calculate_purge_indices`` and
        ``calculate_embargo_indices`` propagate unchanged.

    Examples
    --------
    >>> train = np.arange(100)
    >>> clean = apply_purging_and_embargo(
    ...     train_indices=train,
    ...     test_start=50,
    ...     test_end=60,
    ...     label_horizon=5,
    ...     embargo_size=5,
    ...     n_samples=100
    ... )
    >>> # Removes [45,50) for purging, [50,60) for test and [60,65) for embargo
    >>> len(clean)
    80
    """
    # Positions whose labels overlap the test period
    purged_arr = np.asarray(
        calculate_purge_indices(
            n_samples=n_samples,
            test_start=test_start,
            test_end=test_end,
            label_horizon=label_horizon,
            timestamps=timestamps,
        ),
        dtype=np.intp,
    )

    # Positions immediately following the test period
    embargoed_arr = np.asarray(
        calculate_embargo_indices(
            n_samples=n_samples,
            test_start=test_start,
            test_end=test_end,
            embargo_size=embargo_size,
            embargo_pct=embargo_pct,
            timestamps=timestamps,
        ),
        dtype=np.intp,
    )

    # Also remove test indices themselves
    if timestamps is not None:
        # Use a boolean mask, like the purge/embargo helpers, so the result
        # is correct even if the index is not sorted (searchsorted is not).
        # test_start/test_end were validated as tz-aware Timestamps by the
        # purge calculation above.
        test_mask = (timestamps >= test_start) & (timestamps < test_end)
        test_arr = np.where(test_mask)[0].astype(np.intp, copy=False)
    else:
        # When timestamps is None, test_start/test_end are integer indices.
        # Accept both Python int and numpy integer types. Raise explicitly
        # rather than assert, which is stripped under ``python -O``.
        if not isinstance(test_start, (int, np.integer)):
            raise TypeError(f"Expected int, got {type(test_start)}")
        if not isinstance(test_end, (int, np.integer)):
            raise TypeError(f"Expected int, got {type(test_end)}")
        test_arr = np.arange(int(test_start), int(test_end), dtype=np.intp)

    # np.concatenate handles empty arrays and np.isin ignores duplicates,
    # so no filtering or de-duplication is needed first.
    remove_indices = np.concatenate((purged_arr, embargoed_arr, test_arr))

    # Keep only indices not in the removal set
    train_indices = np.asarray(train_indices)
    return train_indices[~np.isin(train_indices, remove_indices)]
|