ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,501 @@
1
+ """Base class for all time-series cross-validation splitters.
2
+
3
+ This module defines the abstract base class that all ml4t-diagnostic splitters inherit from,
4
+ ensuring compatibility with scikit-learn's cross-validation framework while adding
5
+ support for time-series specific features like purging and embargo.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from collections.abc import Generator
10
+ from typing import TYPE_CHECKING, Any, Union, cast
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ import polars as pl
15
+
16
+ if TYPE_CHECKING:
17
+ from numpy.typing import NDArray
18
+
19
+
20
class BaseSplitter(ABC):
    """Abstract base class for all ml4t-diagnostic time-series splitters.

    This class defines the interface that all splitters must implement to ensure
    compatibility with scikit-learn's model selection tools while providing
    additional functionality for financial time-series validation.

    All splitters should support purging (removing training data that could leak
    information into test data) and embargo (adding gaps between train and test
    sets to account for serial correlation).

    Besides the abstract :meth:`split`, the class offers shared helpers for
    concrete splitters: input validation (``_get_n_samples``,
    ``_validate_data``), session handling (``_validate_session_alignment``,
    ``_get_unique_sessions``, ``_session_to_indices``) and timestamp extraction
    (``_extract_timestamps``).

    Session-Aware Splitting
    -----------------------
    Splitters can optionally align fold boundaries to trading session boundaries
    by setting ``align_to_sessions=True``. This requires the data to have a
    session column (default: 'session_date') that identifies trading sessions.

    Trading sessions are atomic units that should never be split across train/test
    folds. For intraday data (e.g., CME futures with Sunday 5pm - Friday 4pm sessions),
    this prevents subtle lookahead bias from mid-session splits.

    **Integration with qdata library:**

    The session column should be added using the ``qdata`` library's session
    assignment functionality::

        from qdata import DataManager

        manager = DataManager()
        df = manager.load(symbol="BTC", exchange="CME", calendar="CME_Globex_Crypto")
        # df now has 'session_date' column automatically assigned

    Or manually using SessionAssigner::

        from ml4t.data.sessions import SessionAssigner

        assigner = SessionAssigner.from_exchange('CME')
        df_with_sessions = assigner.assign_sessions(df)

    Then use with ml4t-diagnostic splitters::

        from ml4t.diagnostic.splitters import PurgedWalkForwardCV

        cv = PurgedWalkForwardCV(
            n_splits=5,
            align_to_sessions=True,  # Align folds to session boundaries
            session_col='session_date'
        )

        for train_idx, test_idx in cv.split(df_with_sessions):
            # Fold boundaries respect session boundaries
            pass
    """

    @abstractmethod
    def split(
        self,
        X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
        y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
        groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
    ) -> Generator[tuple["NDArray[np.intp]", "NDArray[np.intp]"], None, None]:
        """Generate indices to split data into training and test sets.

        Parameters
        ----------
        X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
            Training data with shape (n_samples, n_features).

        y : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
            Target variable with shape (n_samples,). Always ignored but kept
            for scikit-learn compatibility.

        groups : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
            Group labels for samples, used for multi-asset splitting.
            Shape (n_samples,).

        Yields
        ------
        train : numpy.ndarray
            The training set indices for that split.

        test : numpy.ndarray
            The testing set indices for that split.

        Notes
        -----
        The indices returned are integer positions, not labels or timestamps.
        This ensures compatibility with numpy array indexing and scikit-learn.
        """

    def get_n_splits(
        self,
        X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"] | None = None,
        y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
        groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
    ) -> int:
        """Return the number of splitting iterations in the cross-validator.

        The base implementation is a placeholder that always raises; it is not
        marked abstract, so a subclass that forgets to override it only fails
        when this method is actually called.

        Parameters
        ----------
        X : polars.DataFrame, pandas.DataFrame, numpy.ndarray, or None, default=None
            Training data. Some splitters may use properties of X to determine
            the number of splits.

        y : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
            Always ignored, exists for compatibility.

        groups : polars.Series, pandas.Series, numpy.ndarray, or None, default=None
            Group labels. Some splitters may use this to determine splits.

        Returns
        -------
        n_splits : int
            The number of splitting iterations.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.

        Notes
        -----
        Most splitters can determine the number of splits from their parameters
        alone, but some (like GroupKFold variants) may need to inspect the data.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement get_n_splits()",
        )

    def _get_n_samples(
        self,
        X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
    ) -> int:
        """Get the number of samples in X regardless of type.

        Parameters
        ----------
        X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
            The data to get the sample count from. A ``polars.LazyFrame`` is
            also accepted at runtime, although the annotation does not list it.

        Returns
        -------
        n_samples : int
            The number of samples (rows) in X.

        Raises
        ------
        TypeError
            If X is none of the supported container types.
        """
        if isinstance(X, pl.DataFrame):
            return X.height
        if isinstance(X, pl.LazyFrame):
            # A LazyFrame has no height; the query is materialised just to
            # count its rows.
            return X.collect().height
        if isinstance(X, pd.DataFrame):
            return len(X)
        if isinstance(X, np.ndarray):
            return int(X.shape[0])
        raise TypeError(
            f"X must be a Polars DataFrame, Pandas DataFrame, or numpy array. Got {type(X).__name__}",
        )

    def _validate_data(
        self,
        X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
        y: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
        groups: Union[pl.Series, pd.Series, "NDArray[Any]"] | None = None,
    ) -> int:
        """Validate input data and return the number of samples.

        Parameters
        ----------
        X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
            Training data.

        y : polars.Series, pandas.Series, numpy.ndarray, or None
            Target variable.

        groups : polars.Series, pandas.Series, numpy.ndarray, or None
            Group labels.

        Returns
        -------
        n_samples : int
            The number of samples in the data.

        Raises
        ------
        ValueError
            If the input data has inconsistent lengths.
        TypeError
            If the input data types are not supported.
        """
        n_samples = self._get_n_samples(X)

        # y is validated before groups, so a problem with y is reported first.
        self._check_matching_length(y, "y", n_samples)
        self._check_matching_length(groups, "groups", n_samples)

        return n_samples

    @staticmethod
    def _check_matching_length(
        values: Union[pl.Series, pd.Series, "NDArray[Any]"] | None,
        name: str,
        n_samples: int,
    ) -> None:
        """Check that an optional per-sample input has exactly ``n_samples`` entries.

        Parameters
        ----------
        values : polars.Series, pandas.Series, numpy.ndarray, or None
            The per-sample input to check. ``None`` means "not provided" and
            is accepted without any check.

        name : str
            Argument name used in error messages (``"y"`` or ``"groups"``).

        n_samples : int
            Number of rows in X that ``values`` must match.

        Raises
        ------
        TypeError
            If ``values`` is not a Polars Series, pandas Series, or numpy array.
        ValueError
            If the length of ``values`` differs from ``n_samples``.
        """
        if values is None:
            return

        if isinstance(values, pl.Series | pd.Series):
            n_values = len(values)
        elif isinstance(values, np.ndarray):
            # Only the leading dimension counts as the sample axis.
            n_values = values.shape[0]
        else:
            raise TypeError(
                f"{name} must be a Polars Series, Pandas Series, or numpy array. Got {type(values).__name__}",
            )

        if n_values != n_samples:
            raise ValueError(
                f"X and {name} have inconsistent lengths: X has {n_samples} samples, {name} has {n_values} samples",
            )

    def _validate_session_alignment(
        self,
        X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
        align_to_sessions: bool,
        session_col: str,
    ) -> None:
        """Validate that session column exists if session alignment is enabled.

        Parameters
        ----------
        X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
            Training data that may contain session column.

        align_to_sessions : bool
            Whether session alignment is requested. When False the method
            returns immediately without inspecting X.

        session_col : str
            Name of the session column to look for.

        Raises
        ------
        ValueError
            If align_to_sessions=True but session column is missing or X is not a DataFrame.

        Notes
        -----
        This method provides helpful error messages that guide users to the qdata library
        for session assignment if the required column is missing.
        """
        if not align_to_sessions:
            return  # Nothing to validate when sessions are not used

        # Sessions need column access; anything exposing ``columns`` is
        # treated as a DataFrame here (duck typing, not an isinstance check).
        if not hasattr(X, "columns"):
            raise ValueError(
                f"align_to_sessions=True requires X to be a DataFrame "
                f"(Polars or Pandas), got {type(X).__name__}.\n"
                f"\n"
                f"Session alignment works with tabular data that has a session "
                f"identifier column. NumPy arrays do not support column names."
            )

        # Materialise the column names as a plain list so that the membership
        # test and the error message behave the same for both backends.
        columns = list(cast(Any, X.columns))
        if session_col not in columns:
            raise ValueError(
                f"align_to_sessions=True requires '{session_col}' column in X, "
                f"but it was not found.\n"
                f"\n"
                f"Available columns: {columns}\n"
                f"\n"
                f"To add session dates to your data using the qdata library:\n"
                f"\n"
                f"Option 1 - Using DataManager (recommended):\n"
                f" from qdata import DataManager\n"
                f" manager = DataManager()\n"
                f" df = manager.load(\n"
                f" symbol='BTC',\n"
                f" exchange='CME',\n"
                f" calendar='CME_Globex_Crypto'\n"
                f" )\n"
                f" # df now has '{session_col}' column automatically\n"
                f"\n"
                f"Option 2 - Using SessionAssigner directly:\n"
                f" from ml4t.data.sessions import SessionAssigner\n"
                f" assigner = SessionAssigner.from_exchange('CME')\n"
                f" df_with_sessions = assigner.assign_sessions(df)\n"
                f"\n"
                f"Option 3 - If you have a different session column:\n"
                f" cv = {self.__class__.__name__}(\n"
                f" ...,\n"
                f" align_to_sessions=True,\n"
                f" session_col='your_column_name' # Specify your column\n"
                f" )\n"
                f"\n"
                f"Option 4 - Disable session alignment:\n"
                f" cv = {self.__class__.__name__}(\n"
                f" ...,\n"
                f" align_to_sessions=False # Use standard splitting\n"
                f" )\n"
            )

    def _get_unique_sessions(
        self,
        X: pl.DataFrame | pd.DataFrame,
        session_col: str,
    ) -> pl.Series | pd.Series:
        """Extract unique session identifiers in order of first appearance.

        Parameters
        ----------
        X : polars.DataFrame or pandas.DataFrame
            Data containing session column.

        session_col : str
            Name of the session column.

        Returns
        -------
        sessions : polars.Series or pandas.Series
            Unique session identifiers in order of first appearance. The
            pandas result has a fresh 0..k-1 index.

        Notes
        -----
        Sessions are returned in the order they first appear in the data, which
        is the correct chronological order if the data is sorted by time (as it
        should be for time-series cross-validation).

        Previously this method sorted by session ID, which is incorrect when
        session IDs are not naturally sortable in chronological order.
        """
        if isinstance(X, pl.DataFrame):
            # maintain_order=True preserves order of first appearance
            return X[session_col].unique(maintain_order=True)
        # Anything else is treated as a pandas DataFrame: drop_duplicates
        # keeps the first occurrence of each value without sorting.
        return X[session_col].drop_duplicates().reset_index(drop=True)

    def _session_to_indices(
        self,
        X: pl.DataFrame | pd.DataFrame,
        session_col: str,
    ) -> tuple[list[Any], dict[Any, "NDArray[np.intp]"]]:
        """Map each session to its row indices, preserving appearance order.

        This is the key helper for session-aligned CV. It returns EXACT indices
        per session, not (start, end) boundaries, which is critical for correct
        behavior with non-contiguous or interleaved data.

        Parameters
        ----------
        X : polars.DataFrame or pandas.DataFrame
            Data containing session column.

        session_col : str
            Name of the session column.

        Returns
        -------
        ordered_sessions : list
            Session IDs in order of first appearance.

        session_indices : dict
            Mapping from session ID to numpy array of row indices (sorted).

        Examples
        --------
        >>> # Data with interleaved assets
        >>> X = pl.DataFrame({
        ...     "session": ["A", "A", "B", "A", "B"],
        ...     "asset": ["X", "Y", "X", "X", "Y"]
        ... })
        >>> sessions, indices = splitter._session_to_indices(X, "session")
        >>> sessions
        ['A', 'B']
        >>> indices['A']
        array([0, 1, 3])  # Exact indices, NOT range(0, 3)
        >>> indices['B']
        array([2, 4])
        """
        if isinstance(X, pl.DataFrame):
            # Polars: attach a positional row index, then collect the
            # positions of each session. maintain_order=True keeps the groups
            # in first-appearance order.
            df_with_idx = X.with_row_index("__row_idx__")
            grouped = df_with_idx.group_by(session_col, maintain_order=True).agg(
                pl.col("__row_idx__")
            )
            ordered_sessions = grouped[session_col].to_list()
            session_indices = {
                row[session_col]: np.array(row["__row_idx__"], dtype=np.intp)
                for row in grouped.iter_rows(named=True)
            }
        else:
            # Pandas: groupby().indices gives a dict of positional index arrays.
            grouped = X.groupby(session_col, sort=False)
            session_indices_raw = grouped.indices
            # Appearance order is taken from drop_duplicates rather than from
            # the iteration order of the groupby dict.
            # NOTE(review): groupby drops null keys by default while
            # drop_duplicates keeps them, so a null session value presumably
            # raises KeyError in the lookup below - confirm against pandas.
            ordered_sessions = X[session_col].drop_duplicates().tolist()
            session_indices = {
                session: np.array(session_indices_raw[session], dtype=np.intp)
                for session in ordered_sessions
            }

        return ordered_sessions, session_indices

    def _extract_timestamps(
        self,
        X: Union[pl.DataFrame, pd.DataFrame, "NDArray[Any]"],
        timestamp_col: str | None = None,
    ) -> pd.DatetimeIndex | None:
        """Extract timestamps from data for time-based size calculations.

        This method supports both Polars and pandas DataFrames, enabling
        time-based test_size/train_size specifications (e.g., '4W', '3M').

        Parameters
        ----------
        X : polars.DataFrame, pandas.DataFrame, or numpy.ndarray
            Input data.
        timestamp_col : str or None
            Column name containing timestamps for Polars DataFrames.
            If None, falls back to pandas DatetimeIndex (backward compatible).

        Returns
        -------
        timestamps : pandas.DatetimeIndex or None
            Timestamps as a pandas DatetimeIndex for time-based calculations.
            Returns None if timestamps cannot be extracted.

        Raises
        ------
        ValueError
            For Polars input only: if ``timestamp_col`` is given but missing
            from X, or if the column is not a datetime type.

        Notes
        -----
        For Polars DataFrames:
            - Requires timestamp_col to be specified (otherwise returns None)
            - Column must be datetime type
            - Converts to pandas DatetimeIndex for compatibility with time parsing
            - Timezone-naive values are localized to UTC

        For pandas DataFrames:
            - Uses DatetimeIndex if available
            - Falls back to timestamp_col if index is not datetime
            - A missing or non-datetime timestamp_col yields None, not an error
            - Timestamps are returned as-is; no UTC localization is applied

        For numpy arrays:
            - Returns None (no timestamp information available)
        """
        # Polars DataFrame: timestamps can only come from a named column
        if isinstance(X, pl.DataFrame):
            if timestamp_col is None:
                return None
            if timestamp_col not in X.columns:
                raise ValueError(
                    f"timestamp_col='{timestamp_col}' not found in Polars DataFrame. "
                    f"Available columns: {X.columns}"
                )
            # The dtype check is done on the pandas side, after conversion
            ts_series = X[timestamp_col].to_pandas()
            if not pd.api.types.is_datetime64_any_dtype(ts_series):
                raise ValueError(
                    f"timestamp_col='{timestamp_col}' must be datetime type, "
                    f"got {X[timestamp_col].dtype}"
                )
            idx = pd.DatetimeIndex(ts_series)
            # Ensure timezone awareness (required for purging/embargo)
            if idx.tz is None:
                idx = idx.tz_localize("UTC")
            return idx

        # pandas DataFrame: prefer the index, fall back to the column
        if isinstance(X, pd.DataFrame):
            if isinstance(X.index, pd.DatetimeIndex):
                return X.index
            if timestamp_col is not None and timestamp_col in X.columns:
                ts_series = X[timestamp_col]
                if pd.api.types.is_datetime64_any_dtype(ts_series):
                    return pd.DatetimeIndex(ts_series)
            return None

        # numpy array (or anything else): no timestamp information
        return None

    def __repr__(self) -> str:
        """Return a string representation of the splitter.

        Only the concrete class name is shown; constructor parameters are not
        included.
        """
        return f"{self.__class__.__name__}()"