ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,372 @@
1
+ """Core purging and embargo functionality for time-series cross-validation.
2
+
3
+ This module implements the fundamental algorithms for preventing data leakage
4
+ in financial time-series validation through purging (removing training samples
5
+ whose labels overlap with test data) and embargo (adding gaps to account for
6
+ serial correlation).
7
+
8
+ Based on López de Prado (2018) "Advances in Financial Machine Learning".
9
+ """
10
+
11
+ from typing import TYPE_CHECKING, cast
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+
16
+ if TYPE_CHECKING:
17
+ from numpy.typing import NDArray
18
+
19
+
20
+ def calculate_purge_indices(
21
+ n_samples: int | None = None,
22
+ test_start: int | pd.Timestamp | None = None,
23
+ test_end: int | pd.Timestamp | None = None,
24
+ label_horizon: int | pd.Timedelta = 0,
25
+ timestamps: pd.DatetimeIndex | None = None,
26
+ ) -> list[int]:
27
+ """Calculate indices to purge from training set to prevent label leakage.
28
+
29
+ Purging removes training samples whose labels could contain information
30
+ from the test period. If a feature at time t is used to predict a label
31
+ that depends on information up to time t+h, we must remove training
32
+ samples from [test_start - h, test_start).
33
+
34
+ Parameters
35
+ ----------
36
+ n_samples : int, optional
37
+ Total number of samples when using integer indices.
38
+
39
+ test_start : int or pandas.Timestamp
40
+ Start index/time of test period.
41
+
42
+ test_end : int or pandas.Timestamp
43
+ End index/time of test period (exclusive).
44
+
45
+ label_horizon : int or pandas.Timedelta, default=0
46
+ Forward-looking period of labels. For example, if predicting
47
+ 20-day returns, label_horizon=20 (days).
48
+
49
+ timestamps : pandas.DatetimeIndex, optional
50
+ Timestamps for each sample when using time-based indices.
51
+
52
+ Returns:
53
+ -------
54
+ purged_indices : list of int
55
+ Integer positions of samples to remove from training set.
56
+
57
+ Examples:
58
+ --------
59
+ >>> # Integer indices
60
+ >>> purged = calculate_purge_indices(
61
+ ... n_samples=100, test_start=50, test_end=60, label_horizon=5
62
+ ... )
63
+ >>> purged
64
+ [45, 46, 47, 48, 49]
65
+
66
+ >>> # Timestamp indices
67
+ >>> times = pd.date_range("2020-01-01", periods=100, freq="D")
68
+ >>> purged = calculate_purge_indices(
69
+ ... timestamps=times,
70
+ ... test_start=times[50],
71
+ ... test_end=times[60],
72
+ ... label_horizon=pd.Timedelta("5D")
73
+ ... )
74
+ """
75
+ if timestamps is not None:
76
+ # Time-based purging
77
+ if not isinstance(test_start, pd.Timestamp) or not isinstance(
78
+ test_end,
79
+ pd.Timestamp,
80
+ ):
81
+ raise TypeError(
82
+ "test_start and test_end must be Timestamps when using timestamps",
83
+ )
84
+
85
+ # Validate timezone awareness
86
+ if timestamps.tz is None:
87
+ raise ValueError(
88
+ "timestamps must be timezone-aware. Use timestamps.tz_localize('UTC') or timestamps.tz_convert('UTC')"
89
+ )
90
+ if test_start.tz is None:
91
+ raise ValueError(
92
+ "test_start must be timezone-aware when using timestamps. "
93
+ "Use pd.Timestamp(test_start, tz='UTC') or test_start.tz_localize('UTC')"
94
+ )
95
+ if test_end.tz is None:
96
+ raise ValueError(
97
+ "test_end must be timezone-aware when using timestamps. "
98
+ "Use pd.Timestamp(test_end, tz='UTC') or test_end.tz_localize('UTC')"
99
+ )
100
+
101
+ # Convert all to UTC for consistent calculations
102
+ timestamps = timestamps.tz_convert("UTC")
103
+ test_start = test_start.tz_convert("UTC")
104
+ test_end = test_end.tz_convert("UTC")
105
+
106
+ if not isinstance(label_horizon, pd.Timedelta):
107
+ # Convert integer days to Timedelta
108
+ label_horizon = pd.Timedelta(days=label_horizon)
109
+
110
+ # Calculate purge start time
111
+ purge_start_time = test_start - label_horizon
112
+
113
+ # Find indices to purge
114
+ purge_mask = (timestamps >= purge_start_time) & (timestamps < test_start)
115
+ purged_indices = np.where(purge_mask)[0].tolist()
116
+
117
+ else:
118
+ # Integer-based purging
119
+ if n_samples is None:
120
+ raise ValueError("n_samples required for integer-based purging")
121
+
122
+ # In this branch, test_start and label_horizon are integers
123
+ test_start_int = cast(int, test_start)
124
+ label_horizon_int = cast(int, label_horizon)
125
+
126
+ # Calculate purge start
127
+ purge_start = max(0, test_start_int - label_horizon_int)
128
+
129
+ # Indices to purge are [purge_start, test_start)
130
+ purged_indices = list(range(purge_start, test_start_int))
131
+
132
+ return purged_indices
133
+
134
+
135
+ def calculate_embargo_indices(
136
+ n_samples: int | None = None,
137
+ test_start: int | pd.Timestamp | None = None,
138
+ test_end: int | pd.Timestamp | None = None,
139
+ embargo_size: int | pd.Timedelta | None = None,
140
+ embargo_pct: float | None = None,
141
+ timestamps: pd.DatetimeIndex | None = None,
142
+ ) -> list[int]:
143
+ """Calculate indices to embargo after test set to prevent serial correlation.
144
+
145
+ Embargo removes training samples immediately after the test set to account
146
+ for serial correlation in predictions. This prevents the model from learning
147
+ patterns that persist across the test/train boundary.
148
+
149
+ Parameters
150
+ ----------
151
+ n_samples : int, optional
152
+ Total number of samples when using integer indices.
153
+
154
+ test_start : int or pandas.Timestamp
155
+ Start index/time of test period.
156
+
157
+ test_end : int or pandas.Timestamp
158
+ End index/time of test period (exclusive).
159
+
160
+ embargo_size : int or pandas.Timedelta, optional
161
+ Size of embargo period after test set.
162
+
163
+ embargo_pct : float, optional
164
+ Embargo size as percentage of total samples.
165
+ Either embargo_size or embargo_pct should be specified.
166
+
167
+ timestamps : pandas.DatetimeIndex, optional
168
+ Timestamps for each sample when using time-based indices.
169
+
170
+ Returns:
171
+ -------
172
+ embargo_indices : list of int
173
+ Integer positions of samples to embargo.
174
+
175
+ Examples:
176
+ --------
177
+ >>> # Fixed embargo size
178
+ >>> embargoed = calculate_embargo_indices(
179
+ ... n_samples=100, test_start=50, test_end=60, embargo_size=5
180
+ ... )
181
+ >>> embargoed
182
+ [60, 61, 62, 63, 64]
183
+
184
+ >>> # Percentage embargo
185
+ >>> embargoed = calculate_embargo_indices(
186
+ ... n_samples=100, test_start=50, test_end=60, embargo_pct=0.05
187
+ ... )
188
+ """
189
+ if embargo_size is None and embargo_pct is None:
190
+ return []
191
+
192
+ if embargo_size is not None and embargo_pct is not None:
193
+ raise ValueError("Specify either embargo_size or embargo_pct, not both")
194
+
195
+ if timestamps is not None:
196
+ # Time-based embargo
197
+ if not isinstance(test_start, pd.Timestamp) or not isinstance(
198
+ test_end,
199
+ pd.Timestamp,
200
+ ):
201
+ raise TypeError(
202
+ "test_start and test_end must be Timestamps when using timestamps",
203
+ )
204
+
205
+ # Validate timezone awareness
206
+ if timestamps.tz is None:
207
+ raise ValueError(
208
+ "timestamps must be timezone-aware. Use timestamps.tz_localize('UTC') or timestamps.tz_convert('UTC')"
209
+ )
210
+ if test_start.tz is None:
211
+ raise ValueError(
212
+ "test_start must be timezone-aware when using timestamps. "
213
+ "Use pd.Timestamp(test_start, tz='UTC') or test_start.tz_localize('UTC')"
214
+ )
215
+ if test_end.tz is None:
216
+ raise ValueError(
217
+ "test_end must be timezone-aware when using timestamps. "
218
+ "Use pd.Timestamp(test_end, tz='UTC') or test_end.tz_localize('UTC')"
219
+ )
220
+
221
+ # Convert all to UTC for consistent calculations
222
+ timestamps = timestamps.tz_convert("UTC")
223
+ test_start = test_start.tz_convert("UTC")
224
+ test_end = test_end.tz_convert("UTC")
225
+
226
+ # Calculate embargo size if percentage given
227
+ if embargo_pct is not None:
228
+ total_duration = timestamps[-1] - timestamps[0]
229
+ embargo_size = total_duration * embargo_pct
230
+
231
+ if not isinstance(embargo_size, pd.Timedelta):
232
+ # Convert integer days to Timedelta
233
+ embargo_size = pd.Timedelta(days=cast(int, embargo_size))
234
+
235
+ # Calculate embargo end time
236
+ embargo_end_time = test_end + embargo_size
237
+
238
+ # Find indices to embargo
239
+ embargo_mask = (timestamps >= test_end) & (timestamps < embargo_end_time)
240
+ embargo_indices = np.where(embargo_mask)[0].tolist()
241
+
242
+ else:
243
+ # Integer-based embargo
244
+ if n_samples is None:
245
+ raise ValueError("n_samples required for integer-based embargo")
246
+
247
+ # Calculate embargo size if percentage given
248
+ if embargo_pct is not None:
249
+ embargo_size = int(n_samples * embargo_pct)
250
+
251
+ # Calculate embargo end
252
+ # Either embargo_size was provided or calculated from embargo_pct
253
+ assert embargo_size is not None
254
+ # In this branch, test_end and embargo_size are integers
255
+ test_end_int = cast(int, test_end)
256
+ embargo_size_int = cast(int, embargo_size)
257
+ embargo_end = min(n_samples, test_end_int + embargo_size_int)
258
+
259
+ # Indices to embargo are [test_end, embargo_end)
260
+ embargo_indices = list(range(test_end_int, embargo_end))
261
+
262
+ return embargo_indices
263
+
264
+
265
def apply_purging_and_embargo(
    train_indices: "NDArray[np.intp]",
    test_start: int | pd.Timestamp,
    test_end: int | pd.Timestamp,
    label_horizon: int | pd.Timedelta = 0,
    embargo_size: int | pd.Timedelta | None = None,
    embargo_pct: float | None = None,
    n_samples: int | None = None,
    timestamps: pd.DatetimeIndex | None = None,
) -> "NDArray[np.intp]":
    """Apply both purging and embargo to training indices.

    This is a convenience function that combines purging and embargo
    to clean a set of training indices. It removes the purged samples
    before the test window, the test window itself, and the embargoed
    samples after it.

    Parameters
    ----------
    train_indices : array-like of int
        Initial training indices before purging/embargo. Any array-like
        accepted by ``numpy.asarray`` (e.g. an ndarray or a list) works.

    test_start : int or pandas.Timestamp
        Start index/time of test period. Must be an integer (Python ``int``
        or NumPy integer) when ``timestamps`` is None.

    test_end : int or pandas.Timestamp
        End index/time of test period (exclusive). Must be an integer when
        ``timestamps`` is None.

    label_horizon : int or pandas.Timedelta, default=0
        Forward-looking period of labels.

    embargo_size : int or pandas.Timedelta, optional
        Size of embargo period after test set.

    embargo_pct : float, optional
        Embargo size as a fraction (e.g. ``0.05`` for 5%).

    n_samples : int, optional
        Total number of samples (required for integer indices).

    timestamps : pandas.DatetimeIndex, optional
        Timestamps for each sample when using time-based indices. The test
        window is located with ``searchsorted``, so this index is assumed to
        be sorted ascending.

    Returns
    -------
    clean_indices : numpy.ndarray
        Training indices, in their original order, after removing purged,
        test and embargoed samples.

    Raises
    ------
    TypeError
        If ``timestamps`` is None and ``test_start``/``test_end`` are not
        integers. Errors from the purge/embargo helpers propagate unchanged.

    Examples
    --------
    >>> train = np.arange(100)
    >>> clean = apply_purging_and_embargo(
    ...     train_indices=train,
    ...     test_start=50,
    ...     test_end=60,
    ...     label_horizon=5,
    ...     embargo_size=5,
    ...     n_samples=100
    ... )
    >>> # Removes [45,50) purge, [50,60) test and [60,65) embargo
    >>> len(clean)
    80
    """
    # Validate integer boundaries up front with a real exception: ``assert``
    # statements are stripped when Python runs with -O.
    if timestamps is None:
        for name, value in (("test_start", test_start), ("test_end", test_end)):
            if not isinstance(value, int | np.integer):
                raise TypeError(
                    f"{name} must be an integer index when timestamps is None, "
                    f"got {type(value)}"
                )

    train_arr = np.asarray(train_indices)

    purged_arr = np.asarray(
        calculate_purge_indices(
            n_samples=n_samples,
            test_start=test_start,
            test_end=test_end,
            label_horizon=label_horizon,
            timestamps=timestamps,
        ),
        dtype=np.intp,
    )
    embargoed_arr = np.asarray(
        calculate_embargo_indices(
            n_samples=n_samples,
            test_start=test_start,
            test_end=test_end,
            embargo_size=embargo_size,
            embargo_pct=embargo_pct,
            timestamps=timestamps,
        ),
        dtype=np.intp,
    )

    # The test window itself must also be excluded from training.
    if timestamps is not None:
        # side="left" on both ends yields positions in [test_start, test_end),
        # matching the >= / < boundaries used for purge and embargo.
        test_start_idx = timestamps.searchsorted(test_start, side="left")
        test_end_idx = timestamps.searchsorted(test_end, side="left")
        test_arr = np.arange(test_start_idx, test_end_idx, dtype=np.intp)
    else:
        test_arr = np.arange(
            int(cast(int, test_start)), int(cast(int, test_end)), dtype=np.intp
        )

    # np.isin tolerates duplicates and empty inputs, so no dedup is needed.
    remove_indices = np.concatenate((purged_arr, embargoed_arr, test_arr))
    clean_mask = ~np.isin(train_arr, remove_indices)
    return train_arr[clean_mask]